In [1]:
import json
import os
import random

In [2]:
import GPUtil
import tensorflow as tf
from bert_serving.client import ConcurrentBertClient
from tensorflow.python.estimator.canned.dnn import DNNClassifier
from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.estimator.training import TrainSpec, EvalSpec, train_and_evaluate

In [3]:
# Directory holding the first-stage data set. Set the LAW_DATA_DIR environment
# variable to point somewhere else; the default keeps the original location.
data_dir = os.environ.get('LAW_DATA_DIR', '/home/anna/Downloads/final_all_data/first_stage')
# Both are lists of file names so they can be handed to TextLineDataset as-is.
train_fp = [os.path.join(data_dir, 'train.json')]
eval_fp = [os.path.join(data_dir, 'test.json')]

In [4]:
# Number of text lines grouped into one batch (each batch is encoded in one call).
batch_size = 128
# Parallelism of the dataset `map` step that calls the encoder.
num_parallel_calls = 1
# Intended number of concurrent BERT clients; keep it larger than `num_parallel_calls`.
num_concurrent_clients = 2 * num_parallel_calls

In [5]:
# Shared BERT client used by `get_encodes`. The pool size comes from
# `num_concurrent_clients` so the value configured above actually takes effect
# instead of the library default.
bc = ConcurrentBertClient(max_concurrency=num_concurrent_clients)

In [6]:
# Hard-coded law ids. Their string forms (`laws_str`) are passed to the
# classifier as its label vocabulary, and `len(laws)` is used as `n_classes`.
# NOTE(review): labels drawn from `relevant_articles` are presumably all
# contained in this list -- confirm against the data files.
laws = [184, 336, 314, 351, 224, 132, 158, 128, 223, 308, 341, 349, 382, 238, 369, 248, 266, 313, 127, 340, 288, 172,
        209, 243, 302, 200, 227, 155, 147, 143, 261, 124, 359, 343, 291, 241, 235, 367, 393, 274, 240, 269, 199, 119,
        246, 282, 133, 177, 170, 310, 364, 201, 312, 244, 357, 233, 236, 264, 225, 234, 328, 417, 151, 135, 136, 348,
        217, 168, 134, 237, 262, 150, 114, 196, 303, 191, 392, 226, 267, 272, 212, 353, 315, 205, 372, 215, 350, 275,
        385, 164, 338, 292, 159, 162, 333, 388, 356, 375, 326, 402, 397, 125, 395, 290, 176, 354, 185, 141, 279, 399,
        192, 383, 307, 295, 361, 286, 404, 390, 294, 115, 344, 268, 171, 117, 273, 193, 418, 220, 198, 231, 386, 363,
        346, 210, 270, 144, 347, 280, 281, 118, 122, 116, 360, 239, 228, 305, 130, 152, 389, 276, 213, 186, 413, 285,
        316, 245, 232, 175, 149, 263, 387, 283, 391, 211, 396, 352, 345, 258, 253, 163, 140, 293, 194, 342, 161, 358,
        271, 156, 260, 384, 153, 277, 214]

In [7]:
# String form of every law id, in the same order as `laws`.
laws_str = list(map(str, laws))

In [8]:
def get_encodes(x):
    """Turn a batch of JSON lines into BERT features and string labels.

    Args:
        x: batch of lines, each a JSON object with a 'fact' text and a
            'meta' -> 'relevant_articles' list.

    Returns:
        A `(features, labels)` pair. `features` is whatever `bc.encode`
        returns for the batch of texts. `labels` holds one single-element
        list per sample, containing one randomly chosen relevant article
        as a string.

    Raises:
        IndexError: if a sample's 'relevant_articles' list is empty.
    """
    edge = 50  # characters kept from each end of a long fact
    samples = [json.loads(line) for line in x]
    texts = []
    for sample in samples:
        fact = sample['fact']
        if len(fact) > 2 * edge:
            # Long fact: keep only its head and tail.
            fact = fact[:edge] + fact[-edge:]
        # Short facts are kept whole; slicing them would make head and tail
        # overlap and feed duplicated text to the encoder.
        texts.append(fact)
    features = bc.encode(texts)
    # A sample may list several relevant articles; pick one at random as the label.
    labels = [[str(random.choice(sample['meta']['relevant_articles']))] for sample in samples]
    return features, labels

In [9]:
# Session settings: enable the GPU `allow_growth` option.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# Estimator run settings: where the model is stored and how often it is checkpointed.
run_config = RunConfig(
    save_checkpoints_steps=1000,
    session_config=config,
    model_dir='/home/anna/Downloads/final_all_data/first_stage/law-model')


In [10]:
# Classifier with one 512-unit hidden layer over the 768-wide 'feature' column.
# There is one class per entry in `laws`, and labels are given as the strings in `laws_str`.
estimator = DNNClassifier(
    config=run_config,
    feature_columns=[tf.feature_column.numeric_column('feature', shape=(768,))],
    hidden_units=[512],
    dropout=0.1,
    n_classes=len(laws),
    label_vocabulary=laws_str)

INFO:tensorflow:Using config: {'_model_dir': '/home/anna/Downloads/final_all_data/first_stage/law-model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': gpu_options {
  allow_growth: true
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd23423bf98>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


In [11]:
def input_fn(fp):
    """Build the dataset that feeds the estimator from JSON-lines files.

    Lines are shuffled and repeated, grouped into batches of `batch_size`,
    and each batch is passed to `get_encodes` through `tf.py_func`.

    Args:
        fp: file name(s) handed to `tf.data.TextLineDataset`.

    Returns:
        A dataset of `({'feature': features}, labels)` pairs, where
        `features` is the float32 output and `labels` the string output
        of `get_encodes`.
    """
    def encode_batch(lines):
        # `get_encodes` is plain Python (it calls the BERT client), so it has
        # to be wrapped in py_func to run inside the dataset pipeline.
        return tf.py_func(get_encodes, [lines], [tf.float32, tf.string], name='bert_client')

    return (tf.data.TextLineDataset(fp)
            # `tf.contrib.data.shuffle_and_repeat` is deprecated in favour of this one.
            .apply(tf.data.experimental.shuffle_and_repeat(buffer_size=10000))
            .batch(batch_size)
            .map(encode_batch, num_parallel_calls=num_parallel_calls)
            # The key must match the name of the estimator's feature column.
            .map(lambda x, y: ({'feature': x}, y))
            .prefetch(20))

In [12]:
# Training input: the shared pipeline over the training file(s).
train_spec = TrainSpec(
    input_fn=lambda: input_fn(train_fp),
)

In [13]:
# Evaluation input: the shared pipeline over the evaluation file(s).
# throttle_secs=0 sets no minimum delay between evaluations.
eval_spec = EvalSpec(
    throttle_secs=0,
    input_fn=lambda: input_fn(eval_fp),
)

In [14]:
# Run the train/evaluate loop using the specs defined above.
train_and_evaluate(
    estimator=estimator,
    train_spec=train_spec,
    eval_spec=eval_spec,
)

INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 1000 or save_checkpoints_secs None.
Instructions for updating:
Use `tf.data.experimental.shuffle_and_repeat(...)`.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /home/anna/Downloads/final_all_data/first_stage/law-model/model.ckpt-0
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /home/anna/Downloads/final_all_data/first_stage/law-model/model.ckpt.
INFO:tensorflow:loss = 626.38776, step = 1
INFO:tensorflow:global_step/sec: 0.209274
INFO:tensorflow:loss = 14

KeyboardInterrupt: 