Skip to content

Commit

Permalink
[Refactor] Implement attention-based rnn enc-dec.
Browse files Browse the repository at this point in the history
  • Loading branch information
DevinZ1993 committed Jun 23, 2018
1 parent b444733 commit d7082ce
Show file tree
Hide file tree
Showing 11 changed files with 297 additions and 136 deletions.
9 changes: 5 additions & 4 deletions char2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
# -*- coding:utf-8 -*-

from char_dict import CharDict
from check_file import char2vec_path, file_uptodate
from gensim import models
from numpy.random import uniform
from paths import char2vec_path, check_uptodate
from poems import Poems
from singleton import Singleton
from utils import *
from utils import CHAR_VEC_DIM
import numpy as np
import os


def _gen_char2vec():
Expand All @@ -26,7 +27,7 @@ def _gen_char2vec():
class Char2Vec(Singleton):

def __init__(self):
if not file_uptodate(char2vec_path):
if not check_uptodate(char2vec_path):
_gen_char2vec()
self.embedding = np.load(char2vec_path)
self.char_dict = CharDict()
Expand All @@ -35,7 +36,7 @@ def get_embedding(self):
return self.embedding

def get_vect(self, ch):
return self.char2vec[self.char2int(ch)]
return self.embedding[self.char_dict.char2int(ch)]

def get_vects(self, text):
return np.stack(map(self.get_vect, text)) if len(text) > 0 \
Expand Down
6 changes: 3 additions & 3 deletions char_dict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#! /usr/bin/env python3
#-*- coding:utf-8 -*-

from check_file import char_dict_path, file_uptodate
from paths import raw_dir, char_dict_path, check_uptodate
from singleton import Singleton
from utils import *
from utils import is_cn_char
import os


Expand Down Expand Up @@ -38,7 +38,7 @@ def _gen_char_dict():
class CharDict(Singleton):

def __init__(self):
if not file_uptodate(char_dict_path):
if not check_uptodate(char_dict_path):
_gen_char_dict()
self._int2char = []
self._char2int = dict()
Expand Down
26 changes: 6 additions & 20 deletions data_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#! /usr/bin/env python3
#-*- coding:utf-8 -*-

from check_file import gen_data_path, plan_data_path, file_uptodate
from paths import gen_data_path, plan_data_path, check_uptodate
from poems import Poems
from rank_words import RankedWords
from segment import Segmenter
from utils import *
import argparse
import re
import subprocess

Expand All @@ -15,6 +13,7 @@ def gen_train_data():
print("Generating training data ...")
segmenter = Segmenter()
poems = Poems()
poems.shuffle()
ranked_words = RankedWords()
plan_data = []
gen_data = []
Expand Down Expand Up @@ -52,7 +51,7 @@ def gen_train_data():

def batch_train_data(batch_size):
""" Training data generator for the poem generator."""
if not file_uptodate(gen_data_path):
if not check_uptodate(gen_data_path):
print("Warning: training data is not found!")
gen_train_data()
keywords = []
Expand All @@ -70,24 +69,11 @@ def batch_train_data(batch_size):
contexts.clear()
sentences.clear()
if len(keywords) > 0:
yield keywords, contets, sentences
yield keywords, contexts, sentences


if __name__ == '__main__':
parser = argparse.ArgumentParser(description = 'Training data generation.')
parser.add_argument('--clean', dest = 'clean', default = False,
action = 'store_true', help = 'clean all processed data')
args = parser.parse_args()
if args.clean:
for f in os.listdir(data_dir):
if not re.match('raw', f):
print("Delete %s." % os.path.join(data_dir, f))
os.remove(os.path.join(data_dir, f))
subprocess.run(args=["./char2vec.py"], check = True,
stdout = sys.stdout)
subprocess.run(args=["./rank_words.py"], check = True,
stdout = sys.stdout)
if not file_uptodate(plan_data_path) or \
not file_uptodate(gen_data_path):
if not check_uptodate(plan_data_path) or \
not check_uptodate(gen_data_path):
gen_train_data()

Loading

0 comments on commit d7082ce

Please sign in to comment.