forked from HackerPoet/YouTubeCommenter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Generate.py
160 lines (143 loc) · 5.51 KB
/
Generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Script-level configuration and dataset loading for the comment generator.
import os, random, json
import numpy as np
from scipy import stats
import util
# Length of the rolling window of past word samples fed to the model.
SEQ_SIZE = 8
# NOTE(review): declared but never referenced in this file — confirm intent
# (generate_titles hard-codes its own count).
NUM_TO_GEN = 20
# Directory containing the trained Keras model (model.h5).
MODEL_DIR = 'trained_all/'
# Directory containing the parsed title/comment dictionaries and sentences.
PARSED_DIR = 'parsed_all/'
# When True, the loaded model is rebuilt below with batch size 1 and stateful LSTMs.
MAKE_STATEFUL = False
# When True, the past-sample window is shifted/written in reverse order.
IS_REVERSE = False
#Load titles
title_words, title_word_to_ix = util.load_title_dict(PARSED_DIR)
title_dict_size = len(title_words)
title_sentences = util.load_title_sentences(PARSED_DIR)
#Load comments
comment_words, comment_word_to_ix = util.load_comment_dict(PARSED_DIR)
comment_dict_size = len(comment_words)
comment_sentences = util.load_comment_sentences(PARSED_DIR)
# Titles and comments are expected to be paired one-to-one.
assert(len(title_sentences) == len(comment_sentences))
def word_ixs_to_str(word_ixs, is_title):
    """Render a list of word indices as a readable sentence.

    Joins words with spaces, except that punctuation and contraction
    suffixes (after an apostrophe) attach directly to the previous text.
    The first character of the result is upper-cased.
    """
    vocab = title_words if is_title else comment_words
    text = ""
    for ix in word_ixs:
        word = vocab[ix]
        if not text or word in ('.', ',', "'", '!', '?', ':', ';', '...'):
            # Punctuation (and the very first word) gets no leading space.
            text += word
        elif text.endswith("'") and word in ('s', 're', 't', 'll', 've', 'd'):
            # Contraction suffix: glue onto the apostrophe.
            text += word
        else:
            text += ' ' + word
    if text:
        text = text[0].upper() + text[1:]
    return text
def probs_to_word_ix(pk, is_first):
    """Sample a word index from the model's probability distribution.

    Args:
        pk: 1-D array-like of per-word probabilities (need not be normalized
            after adjustment; it is renormalized here).
        is_first: True when sampling the first word of a sentence; the
            end-of-sentence token (index 0) is then disallowed.

    Returns:
        int: the sampled word index.
    """
    # Work on a float copy so the caller's array is never mutated in place
    # (the original implementation clobbered the input).
    pk = np.array(pk, dtype=np.float64)
    if is_first:
        # A sentence can never start with the end token.
        pk[0] = 0.0
    else:
        # Sharpen the distribution (square the probabilities) to favor
        # high-confidence words.
        pk *= pk
    pk /= np.sum(pk)
    # Draw from the discrete distribution over word indices.
    xk = np.arange(pk.shape[0], dtype=np.int32)
    custm = stats.rv_discrete(name='custm', values=(xk, pk))
    return custm.rvs()
def pred_text(model, context, max_len=64):
    """Generate a sequence of word indices from the model.

    Repeatedly feeds the title context plus the recent sample window into
    the model and samples the next word, stopping at the end token (0) or
    after max_len words.  Returns the list of sampled word indices
    (excluding the end token).
    """
    output = []
    # Add a batch dimension for Keras predict.
    context = np.expand_dims(context, axis=0)
    if MAKE_STATEFUL:
        # Stateful model: feed one past sample per step; the LSTM keeps state.
        past_sample = np.zeros((1,), dtype=np.int32)
    else:
        # Stateless model: feed a rolling window of the last SEQ_SIZE samples.
        past_sample = np.zeros((SEQ_SIZE,), dtype=np.int32)
    while len(output) < max_len:
        # [-1] takes the prediction for the last timestep.
        pk = model.predict([context, np.expand_dims(past_sample, axis=0)], batch_size=1)[-1]
        if MAKE_STATEFUL:
            pk = pk[0]
        else:
            # Shift the window to make room for the newest sample; direction
            # depends on IS_REVERSE.
            past_sample = np.roll(past_sample, 1 if IS_REVERSE else -1)
        new_sample = probs_to_word_ix(pk, len(output) == 0)
        # Write the new sample at the window edge matching the shift direction.
        past_sample[0 if IS_REVERSE else -1] = new_sample
        if new_sample == 0:
            # End-of-sentence token: stop without emitting it.
            break
        output.append(new_sample)
    # Clear LSTM state so the next call starts fresh (no-op if stateless).
    model.reset_states()
    return output
#Load Keras and Theano
print("Loading Keras...")
import os, math
# Force the TensorFlow backend before Keras is imported.
os.environ['KERAS_BACKEND'] = "tensorflow"
import tensorflow as tf
print("Tensorflow Version: " + tf.__version__)
import keras
print("Keras Version: " + keras.__version__)
from keras.layers import Input, Dense, Activation, Dropout, Flatten, Reshape, RepeatVector, TimeDistributed, concatenate
from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D, Convolution1D
from keras.layers.embeddings import Embedding
from keras.layers.local import LocallyConnected2D
from keras.layers.pooling import MaxPooling2D
from keras.layers.noise import GaussianNoise
from keras.layers.normalization import BatchNormalization
from keras.layers.recurrent import LSTM, SimpleRNN, GRU
from keras.models import Model, Sequential, load_model, model_from_json
from keras.optimizers import Adam, RMSprop, SGD
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l1
from keras.utils import plot_model, to_categorical
from keras import backend as K
K.set_image_data_format('channels_first')
#Fix bug with sparse_categorical_accuracy
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
# NOTE(review): this rebinds K from keras.backend to tensorflow.python.keras.backend;
# everything after this line (new_sparse_categorical_accuracy) uses the TF version.
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import array_ops
def new_sparse_categorical_accuracy(y_true, y_pred):
    """Replacement metric for Keras' sparse_categorical_accuracy.

    Registered as a custom object when loading the model below; computes
    per-element accuracy between integer labels y_true and the argmax of
    the predicted distribution y_pred.
    """
    y_pred_rank = ops.convert_to_tensor(y_pred).get_shape().ndims
    y_true_rank = ops.convert_to_tensor(y_true).get_shape().ndims
    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
    if (y_true_rank is not None) and (y_pred_rank is not None) and (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))):
        y_true = array_ops.squeeze(y_true, [-1])
    # Predicted class is the argmax over the last (vocabulary) axis.
    y_pred = math_ops.argmax(y_pred, axis=-1)
    # If the predicted output and actual output types don't match, force cast them
    # to match.
    if K.dtype(y_pred) != K.dtype(y_true):
        y_pred = math_ops.cast(y_pred, K.dtype(y_true))
    return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
#Load the model
print("Loading Model...")
# The custom accuracy metric must be supplied or load_model fails to resolve it.
model = load_model(MODEL_DIR + 'model.h5', custom_objects={'new_sparse_categorical_accuracy':new_sparse_categorical_accuracy})
if MAKE_STATEFUL:
    # Rebuild the model with batch size 1, sequence length 1, and stateful
    # LSTMs by editing its JSON config, then restore the trained weights.
    weights = model.get_weights()
    model_json = json.loads(model.to_json())
    layers = model_json['config']['layers']
    for layer in layers:
        if 'batch_input_shape' in layer['config']:
            # Fix the batch dimension to 1 for stateful inference.
            layer['config']['batch_input_shape'][0] = 1
            if layer['config']['batch_input_shape'][1] == SEQ_SIZE:
                # Collapse the time dimension to a single step.
                layer['config']['batch_input_shape'][1] = 1
        if layer['class_name'] == 'Embedding':
            layer['config']['input_length'] = 1
        if layer['class_name'] == 'RepeatVector':
            layer['config']['n'] = 1
        if layer['class_name'] == 'LSTM':
            # The trained model must have been stateless; make it stateful now.
            assert(layer['config']['stateful'] == False)
            layer['config']['stateful'] = True
    print(json.dumps(model_json, indent=4, sort_keys=True))
    model = model_from_json(json.dumps(model_json))
    model.set_weights(weights)
#plot_model(model, to_file='temp.png', show_shapes=True)
def generate_titles(my_title, num_comments=10):
    """Generate and print comments for a video title.

    NOTE(review): despite the name, this generates *comments* conditioned
    on the given title; kept as-is for interface compatibility.

    Args:
        my_title: raw title string; cleaned and converted to a bag-of-words
            vector over the title vocabulary.
        num_comments: how many comments to generate and print (default 10,
            matching the original hard-coded count).
    """
    my_title = util.clean_text(my_title)
    my_words = my_title.split(' ')
    # Echo the title, upper-casing the words the vocabulary recognizes so
    # the user can see what the model actually understood.
    print(' '.join((w.upper() if w in title_word_to_ix else w) for w in my_words) + '\n')
    # Out-of-vocabulary words are silently dropped.
    my_title_ixs = [title_word_to_ix[w] for w in my_words if w in title_word_to_ix]
    my_title_sample = util.bag_of_words(my_title_ixs, title_dict_size)
    for i in range(num_comments):
        print(' ' + word_ixs_to_str(pred_text(model, my_title_sample), False))
    print('')
# Interactive loop: read a title from stdin and print generated comments.
# Exits only via EOF/KeyboardInterrupt.
while True:
    my_title = input('Enter Title:\n')
    generate_titles(my_title)