# Automatically configure parameters and generate the run scripts

In [1]:
import tensorflow as tf

def load_hparams(model_dir):
  """
  Load hparams from an existing model directory.

  Args:
    model_dir: Directory expected to contain a JSON file named "hparams".

  Returns:
    A `tf.contrib.training.HParams` built from the JSON contents, or None
    when the file does not exist or a ValueError is raised while parsing it
    or building the HParams object.
  """
  # Imported locally because the file's top-level imports only provide
  # tensorflow; without these the function fails with NameError.
  import codecs
  import json
  import os

  hparams_file = os.path.join(model_dir, "hparams")
  if not tf.gfile.Exists(hparams_file):
    return None

  # Plain print is used, like everywhere else in this file.
  print("# Loading hparams from %s" % hparams_file)
  with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f:
    try:
      hparams_values = json.load(f)
      return tf.contrib.training.HParams(**hparams_values)
    except ValueError:
      print("  can't load hparams file")
      return None

def load_or_update_configs(config_path, default_dict=None):
    """
    Load configs from an existing config file.

    The file is read line by line as ``key=value`` pairs. Blank lines and
    lines starting with ``#`` are skipped; any other line that does not
    contain exactly one ``=`` is reported as a bad line and ignored.

    Args:
        config_path: Path of the UTF-8 encoded config file to read.
        default_dict: Optional dict that is updated in place with the parsed
            pairs and returned. A fresh dict is created when omitted.

    Returns:
        The updated dict (keys and values are stripped strings), or None
        when the file cannot be opened or decoded.
    """
    # A fresh dict per call: a mutable default would be shared (and polluted)
    # across every call that omits the argument.
    configs = {} if default_dict is None else default_dict
    try:
        # Read-only access is sufficient; 'r+' would also demand write
        # permission on the config file.
        with open(config_path, 'r', encoding='utf-8') as fin:
            print("Load config file: %s" % config_path)
            for raw_line in fin:
                line = raw_line.strip()
                # Skip blank lines explicitly: indexing line[0] on an empty
                # string raises IndexError and would abort the whole load.
                if not line or line.startswith("#"):
                    continue
                items = line.split('=')
                if len(items) != 2:
                    print("Bad line: %s" % line)
                    continue
                key = items[0].strip()
                value = items[1].strip()
                configs[key] = value
                print("%s=%s" % (key, value))
    except (OSError, UnicodeDecodeError) as err:
        # Keep the "None on failure" contract, but say why it failed.
        print("Failed to load config file %s: %s" % (config_path, err))
        return None
    return configs

# Word-level or single-level models

In [2]:
def generate_eval_bash(model_id, infer_beam,out_dir,vocab_prefix,inference_input_file,inference_output_file,log_path):
    """
    Build the inference command for a model and save it as a shell script.

    The command is written to ``../eval_<model_id>.sh`` (relative to the
    current working directory) and also returned to the caller.
    """
    # The trailing newline plus four spaces is part of the emitted text.
    template = (
        "python3 -m nrm.nrm "
        " --infer_beam_width=%d"
        " --out_dir=%s"
        " --vocab_prefix=%s"
        " --inference_input_file=%s"
        " --inference_output_file=%s"
        " >> %s\n    "
    )
    arguments = (
        infer_beam,
        out_dir,
        vocab_prefix,
        inference_input_file,
        inference_output_file,
        log_path,
    )
    command = template % arguments

    script_path = "../eval_%s.sh" % model_id
    with open(script_path, 'w+', encoding='utf-8') as script:
        script.write(command)
    return command
        
def generate_score_bash(model_id,ref_path,trans_path,out_path,metrics):
    """
    Build the scoring command for a model and save it as a shell script.

    The command is written to ``../score_<model_id>.sh`` (relative to the
    current working directory) and also returned to the caller.
    """
    # The trailing newline plus four spaces is part of the emitted text.
    template = "python3 -m nrm.utils.evaluation_utils %s %s %s %s %s\n    "
    command = template % (model_id, ref_path, trans_path, out_path, metrics)

    script_path = "../score_%s.sh" % model_id
    with open(script_path, 'w+', encoding='utf-8') as script:
        script.write(command)
    return command



# Maps the ``language`` argument to the offset config that is applied on top
# of configs/basic.config. Languages not listed here (such as the default
# 'chinese') use the basic config only.
_LANGUAGE_OFFSET_CONFIGS = {
    'english': 'configs/en_wl_offset.config',
    'chinese_char': 'configs/cn_charlevel_offset.config',
    'chinese_bpe': 'configs/cn_subword_offset.config',
    'english_bpe': 'configs/en_subword_offset.config',
    'english_char': 'configs/en_charlevel_offset.config',
    'en_cnnencdec': 'configs/en_cnnencdec.config',
    'cnnencdec': 'configs/cnnencdec.config',
    'cn_hl': 'configs/cn_hl.config',
    'en_hl': 'configs/en_hl.config',
}


def generate_word_level_model(data_space,model_id,gpu='0',language='chinese',vocab_prefix='vocab',train_prefix='train',test_prefix='test',dev_prefix='dev',preset_configs=None):
    """
    Generate the train, eval and score shell scripts for one model.

    The configuration is assembled from the data prefixes, then
    ``configs/basic.config``, then the language-specific offset config (if
    any), and finally ``preset_configs``; later sources override earlier ones.
    Three files are written relative to the current working directory:
    ``../eval_<model_id>.sh``, ``../score_<model_id>.sh`` and
    ``../<model_id>.sh`` (training command followed by the eval and score
    commands).

    Args:
        data_space: Prefix prepended (by plain string concatenation) to the
            vocab/train/test/dev prefixes.
        model_id: Model name; used for ``out_dir`` and all output file names.
        gpu: Currently unused (the CUDA_VISIBLE_DEVICES export is disabled).
        language: Key selecting the offset config to load after the basic one.
        vocab_prefix, train_prefix, test_prefix, dev_prefix: Data file
            prefixes inside ``data_space``.
        preset_configs: Optional dict of overrides applied last.

    Raises:
        IOError: If a required config file cannot be loaded.
        KeyError: If the final configuration has no ``metrics`` entry.
    """
    config = {
        'vocab_prefix': data_space + vocab_prefix,
        'train_prefix': data_space + train_prefix,
        'test_prefix': data_space + test_prefix,
        'dev_prefix': data_space + dev_prefix,
        'out_dir' : "models/" + model_id ,
    }

    config_files = ['configs/basic.config']
    offset_config = _LANGUAGE_OFFSET_CONFIGS.get(language)
    if offset_config is not None:
        config_files.append(offset_config)
    for config_file in config_files:
        config = load_or_update_configs(config_file, config)
        if config is None:
            # The loader signals failure with None; stop here with a clear
            # message instead of failing later while indexing None.
            raise IOError("Failed to load config file: %s" % config_file)

    # Explicit presets win over everything loaded from the config files.
    for key, value in (preset_configs or {}).items():
        config[key] = value
        print("preset: %s=%s" % (key, value))

    # Evaluation and scoring scripts.
    infer_beam = 10
    inference_input_file = config['test_prefix'] + '.message'
    inference_output_file = 'infer_test/' + model_id + ".test.txt"
    log_path = 'infer_test/log/' + model_id + ".test.txt"
    eval_command = generate_eval_bash(
        model_id, infer_beam, config['out_dir'], config['vocab_prefix'],
        inference_input_file, inference_output_file, log_path)
    score_path = 'infer_test/scores/' + model_id + ".test.txt"
    ref_path = config['test_prefix'] + '.response'
    score_command = generate_score_bash(
        model_id, ref_path, inference_output_file, score_path, config['metrics'])

    # Training script: one --key=value flag per config entry, followed by the
    # eval and score commands.
    with open("../%s.sh" % model_id, 'w+', encoding='utf-8') as fout:
        print("../%s.sh" % model_id)
        flags = "".join("    --%s=%s  " % (key, value) for key, value in config.items())
        train_command = "python3 -m nrm.nrm  " + flags + "   >> logs/%s.txt \n" % model_id
        fout.write(train_command)
        fout.write('\n%s\n%s\n' % (eval_command, score_command))
    
    

In [5]:
# Baseline models: all share the "vocab.40000.separate" vocabulary and the
# default train/test/dev prefixes. Each row is (data_space, model_id, language);
# 'chinese' is generate_word_level_model's default language.
_BASELINE_RUNS = (
    ('/ldev/tensorflow/nmt2/nmt/data/enwordlevel/', 'enword_lstm', 'english'),
    ('/ldev/tensorflow/nmt2/nmt/data/wordlevel/', 'chinese_lstm', 'chinese'),
    ('/ldev/tensorflow/nmt2/nmt/data/charspace/', 'chinese_char_lstm', 'chinese_char'),
    ('/home/mebiuw/nmt/data/encharspace/', 'english_char_lstm', 'english_char'),
    ('/ldev/tensorflow/nmt2/nmt/data/bpelevel/', 'bpe_lstm', 'chinese_bpe'),
    ('/ldev/tensorflow/nmt2/nmt/data/enbpelevel/', 'enbpe_lstm', 'english_bpe'),
    ('/home/mebiuw/nmt/data/encharspace/', 'encnn', 'en_cnnencdec'),
    ('/ldev/tensorflow/nmt2/nmt/data/charspace/', 'cnnencdec', 'cnnencdec'),
)
for _data_space, _model_id, _language in _BASELINE_RUNS:
    generate_word_level_model(_data_space, _model_id, language=_language, vocab_prefix="vocab.40000.separate")

# Every remaining model uses the same "*.40000" file prefixes.
_HL_PREFIXES = {
    'vocab_prefix': "vocab.40000",
    'train_prefix': 'train.40000',
    'test_prefix': 'test.40000',
    'dev_prefix': 'dev.40000',
}

# Models without presets.
# (A disabled variant of en_hl read from /home/mebiuw/nmt/data/enhllevel/.)
generate_word_level_model('/ldev/tensorflow/nmt2/nmt/data/enhllevel/', 'en_hl', language='en_hl', **_HL_PREFIXES)
generate_word_level_model('/ldev/tensorflow/nmt2/nmt/data/hllevel/', 'cn_hl', language='cn_hl', **_HL_PREFIXES)

# Preset variants. Each row is
# (src_embed_type, cn data_space, cn model_id, en data_space, en model_id);
# the cn model is generated before the en model of the same row.
_PRESET_RUNS = (
    ('rl1_cnn_segment',
     '/ldev/tensorflow/nmt2/nmt/data/hllevel/', 'cn_hl1',
     '/home/mebiuw/nmt/data/enhllevel/', 'en_hl1'),
    ('rl2_cnn_segment',
     '/ldev/tensorflow/nmt2/nmt/data/hllevel/', 'cn_hl2',
     '/home/mebiuw/nmt/data/enhllevel/', 'en_hl2'),
    ('rnn_segment',
     '/ldev/tensorflow/nmt2/nmt/data/hllevel/', 'cn_rnn_hl',
     '/home/mebiuw/nmt/data/enhllevel/', 'en_rnn_hl'),
    ('rl1_rnn_segment',
     '/ldev/tensorflow/nmt2/nmt/data/hllevel/', 'cn_rnn_hl1',
     '/home/mebiuw/nmt/data/enhllevel/', 'en_rnn_hl1'),
    ('rl1_rnn_segment_attention',
     '/ldev/tensorflow/nmt2/nmt/data/hllevel/', 'cn_rnn_hlat',
     '/home/mebiuw/nmt/data/enhllevel/', 'en_rnn_hlat'),
    ('rl1_rnn_segment_attention',
     '/ldev/tensorflow/nmt2/nmt/data/allevel/', 'cn_albpe',
     '/ldev/tensorflow/nmt2/nmt/data/enallevel/', 'en_albpe'),
    ('rl1_rnn_segment_attentionv2',
     '/ldev/tensorflow/nmt2/nmt/data/allevel/', 'cn_albpe2',
     '/ldev/tensorflow/nmt2/nmt/data/enallevel/', 'en_albpe2'),
    ('rl1_rnn_segment_attentionv2_highway',
     '/ldev/tensorflow/nmt2/nmt/data/hllevel/', 'cn_alhw',
     '/ldev/tensorflow/nmt2/nmt/data/enhllevel/', 'en_alhw'),
)
for _embed_type, _cn_space, _cn_id, _en_space, _en_id in _PRESET_RUNS:
    # A fresh module-level preset dict per group.
    preset_configs = {
        'src_embed_type': _embed_type
    }
    generate_word_level_model(_cn_space, _cn_id, preset_configs=preset_configs, language='cn_hl', **_HL_PREFIXES)
    generate_word_level_model(_en_space, _en_id, preset_configs=preset_configs, language='en_hl', **_HL_PREFIXES)


Load config file: configs/basic.config
num_units=160
embed_dim=512
num_layers=4
unit_type=lstm
share_vocab=False
src_max_len=20
tgt_max_len=20
batch_size=256
encoder_type=bi
infer_batch_size=10
attention=luong
src=message
tgt=response
num_train_steps=1000000
steps_per_stats=100
metrics=rouge
Load config file: configs/en_wl_offset.config
src_max_len=30
tgt_max_len=30
../enword_lstm.sh
Load config file: configs/basic.config
num_units=160
embed_dim=512
num_layers=4
unit_type=lstm
share_vocab=False
src_max_len=20
tgt_max_len=20
batch_size=256
encoder_type=bi
infer_batch_size=10
attention=luong
src=message
tgt=response
num_train_steps=1000000
steps_per_stats=100
metrics=rouge
../chinese_lstm.sh
Load config file: configs/basic.config
num_units=160
embed_dim=512
num_layers=4
unit_type=lstm
share_vocab=False
src_max_len=20
tgt_max_len=20
batch_size=256
encoder_type=bi
infer_batch_size=10
attention=luong
src=message
tgt=response
num_train_steps=1000000
steps_per_stats=100
metrics=rouge
Load con

Load config file: configs/en_hl.config
src_max_len=35
tgt_max_len=50
src_embed_type=cnn_segment
charcnn_high_way_layer=4
charcnn_high_way_type=uniform
charcnn_max_window_size=8
charcnn_min_window_size=1
charcnn_filters_per_windows=100
high_way_layer=4
charcnn_relu=relu
seg_embed_dim=160
seg_len=10
seg_embed_mode=separate
metrics=rouge@hybrid
flexible_charcnn_windows=1/50-2/100-3/150-4/200-5/200-7/200-8/200
preset: src_embed_type=rl1_rnn_segment
../en_rnn_hl1.sh
Load config file: configs/basic.config
num_units=160
embed_dim=512
num_layers=4
unit_type=lstm
share_vocab=False
src_max_len=20
tgt_max_len=20
batch_size=256
encoder_type=bi
infer_batch_size=10
attention=luong
src=message
tgt=response
num_train_steps=1000000
steps_per_stats=100
metrics=rouge
Load config file: configs/cn_hl.config
src_max_len=30
tgt_max_len=30
src_embed_type=cnn_segment
charcnn_high_way_layer=4
charcnn_high_way_type=uniform
charcnn_max_window_size=3
charcnn_min_window_size=1
charcnn_filters_per_windows=200
high_w