Commit

Adding the ability to "prime" the network with an author's style (#4)
* Adding the priming option for generation
* Fix comments for something more readable
kristofbc authored and Grzego committed Dec 17, 2017
1 parent cd0fe03 commit c2c9704
Showing 10 changed files with 55 additions and 20 deletions.
Binary file added data/styles.pkl
Binary file not shown.
75 changes: 55 additions & 20 deletions generate.py
@@ -13,6 +13,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument('--model', dest='model_path', type=str, default=os.path.join('pretrained', 'model-29'))
 parser.add_argument('--text', dest='text', type=str, default=None)
+parser.add_argument('--style', dest='style', type=int, default=None)
 parser.add_argument('--bias', dest='bias', type=float, default=1.)
 parser.add_argument('--force', dest='force', action='store_true', default=False)
 parser.add_argument('--animation', dest='animation', action='store_true', default=False)
@@ -47,47 +48,70 @@ def cumsum(points):
     return np.concatenate([sums, points[:, 2:]], axis=1)
 
 
-def sample_text(sess, args_text, translation):
+def sample_text(sess, args_text, translation, style=None):
     fields = ['coordinates', 'sequence', 'bias', 'e', 'pi', 'mu1', 'mu2', 'std1', 'std2',
               'rho', 'window', 'kappa', 'phi', 'finish', 'zero_states']
     vs = namedtuple('Params', fields)(
         *[tf.get_collection(name)[0] for name in fields]
     )
 
     text = np.array([translation.get(c, 0) for c in args_text])
-    sequence = np.eye(len(translation), dtype=np.float32)[text]
-    sequence = np.expand_dims(np.concatenate([sequence, np.zeros((1, len(translation)))]), axis=0)
-
     coord = np.array([0., 0., 1.])
     coords = [coord]
 
+    # Prime the model with the author style if requested
+    prime_len = 0
+    if style is not None:
+        # Priming consist of joining to a real pen-position and character sequences the synthetic sequence to generate
+        # and set the synthetic pen-position to a null vector (the positions are sampled from the MDN)
+        style_coords, style_text = style
+        prime_len = len(style_coords)
+        coords = list(style_coords)
+        coord = coords[0] # Set the first pen stroke as the first element to process
+        text = np.r_[style_text, text] # concatenate on 1 axis the prime text + synthesis character sequence
+        sequence_prime = np.eye(len(translation), dtype=np.float32)[style_text]
+        sequence_prime = np.expand_dims(np.concatenate([sequence_prime, np.zeros((1, len(translation)))]), axis=0)
+
+    sequence = np.eye(len(translation), dtype=np.float32)[text]
+    sequence = np.expand_dims(np.concatenate([sequence, np.zeros((1, len(translation)))]), axis=0)
+
     phi_data, window_data, kappa_data, stroke_data = [], [], [], []
     sess.run(vs.zero_states)
     for s in range(1, 60 * len(args_text) + 1):
-        print('\r[{:5d}] sampling...'.format(s), end='')
+        is_priming = False
+        if s < prime_len:
+            is_priming = True
+
+        print('\r[{:5d}] sampling... {}'.format(s, 'priming' if is_priming else 'synthesis'))
+
         e, pi, mu1, mu2, std1, std2, rho, \
         finish, phi, window, kappa = sess.run([vs.e, vs.pi, vs.mu1, vs.mu2,
                                                vs.std1, vs.std2, vs.rho, vs.finish,
                                                vs.phi, vs.window, vs.kappa],
                                               feed_dict={
                                                   vs.coordinates: coord[None, None, ...],
-                                                  vs.sequence: sequence,
+                                                  vs.sequence: sequence_prime if is_priming else sequence,
                                                   vs.bias: args.bias
                                               })
 
-        phi_data += [phi[0, :]]
-        window_data += [window[0, :]]
-        kappa_data += [kappa[0, :]]
-        # ---
-        g = np.random.choice(np.arange(pi.shape[1]), p=pi[0])
-        coord = sample(e[0, 0], mu1[0, g], mu2[0, g],
-                       std1[0, g], std2[0, g], rho[0, g])
-        coords += [coord]
-        stroke_data += [[mu1[0, g], mu2[0, g], std1[0, g], std2[0, g], rho[0, g], coord[2]]]
-
-        if not args.force and finish[0, 0] > 0.8:
-            print('\nFinished sampling!\n')
-            break
+        if is_priming:
+            # Use the real coordinate if priming
+            coord = coords[s]
+        else:
+            # Synthesis mode
+            phi_data += [phi[0, :]]
+            window_data += [window[0, :]]
+            kappa_data += [kappa[0, :]]
+            # ---
+            g = np.random.choice(np.arange(pi.shape[1]), p=pi[0])
+            coord = sample(e[0, 0], mu1[0, g], mu2[0, g],
+                           std1[0, g], std2[0, g], rho[0, g])
+            coords += [coord]
+            stroke_data += [[mu1[0, g], mu2[0, g], std1[0, g], std2[0, g], rho[0, g], coord[2]]]
+
+            if not args.force and finish[0, 0] > 0.8:
+                print('\nFinished sampling!\n')
+                break
 
     coords = np.array(coords)
     coords[-1, 2] = 1.
@@ -115,7 +139,18 @@ def main():
             else:
                 args_text = input('What to generate: ')
 
-            phi_data, window_data, kappa_data, stroke_data, coords = sample_text(sess, args_text, translation)
+            style = None
+            if args.style is not None:
+                style = None
+                with open(os.path.join('data', 'styles.pkl'), 'rb') as file:
+                    styles = pickle.load(file)
+
+                if args.style > len(styles[0]):
+                    raise ValueError('Requested style is not in style list')
+
+                style = [styles[0][args.style], styles[1][args.style]]
+
+            phi_data, window_data, kappa_data, stroke_data, coords = sample_text(sess, args_text, translation, style)
 
             strokes = np.array(stroke_data)
             epsilon = 1e-8
Binary file added imgs/style0.png
Binary file added imgs/style1.png
Binary file added imgs/style2.png
Binary file added imgs/style3.png
Binary file added imgs/style4.png
Binary file added imgs/style5.png
Binary file added imgs/style6.png
Binary file added imgs/style7.png

2 comments on commit c2c9704

@srini1948

How do you create a styles.pkl file from the strokes.xml files?

Thanks.

@Grzego
Owner

@Grzego Grzego commented on c2c9704 Jun 2, 2019

@srini1948 Styles (if I recall correctly) are in this case just samples from the original dataset, so they can be used to prime the network and generate a sequence written in a similar style.
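
To make the priming idea concrete, here is a stripped-down sketch of the control flow, not the code from this commit. The function `run_with_priming` and its `step` argument are hypothetical: `step` stands in for the `sess.run` plus `sample` calls in generate.py and is assumed to carry the network state itself. While priming, the prediction is discarded and the next real point is fed in; afterwards each sampled point is fed back as the next input.

```python
import numpy as np


def run_with_priming(step, style_coords, n_generate):
    """Feed real pen offsets first, then feed the model's own samples back in.

    step: hypothetical callable taking the current [dx, dy, end_of_stroke]
          row and returning the next sampled row.
    """
    coords = [np.asarray(c, dtype=np.float32) for c in style_coords]
    if not coords:
        raise ValueError('style_coords must contain at least one point')
    prime_len = len(coords)
    coord = coords[0]
    for s in range(1, prime_len + n_generate):
        predicted = step(coord)
        if s < prime_len:
            # Priming: ignore the prediction and feed the next real point.
            coord = coords[s]
        else:
            # Synthesis: keep the sampled point and feed it back in.
            coord = np.asarray(predicted, dtype=np.float32)
            coords.append(coord)
    return np.stack(coords), prime_len


if __name__ == '__main__':
    # Toy stand-in for the network: just adds 1 to every component.
    toy_step = lambda c: c + 1.0
    out, prime_len = run_with_priming(toy_step, [[0, 0, 0], [1, 1, 0]], n_generate=2)
    print(out.shape, prime_len)
```

The loop bound in this sketch includes the priming steps, which is a choice made for the illustration; the first `prime_len` rows of the result are the style sample and the rest are generated.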

To answer your question: after generating the dataset from all strokes.xml files, you can pick some of the lines that best represent a given style and use them later to prime the network. styles.pkl is a file that contains those selected samples.
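
A rough sketch of that selection step could look like the following. It is an assumption about how such a file might be assembled, not the script that produced data/styles.pkl: the names `build_styles` and `save_styles`, the toy sample, and the toy translation dictionary are all hypothetical. It packs hand-picked samples into three parallel lists (stroke arrays, character indices, text labels) and pickles them.

```python
import os
import pickle
import tempfile

import numpy as np


def build_styles(samples, translation):
    """Pack (strokes, text) pairs into three parallel lists.

    samples: iterable of (strokes, text); strokes has shape [N, 3]
    translation: dict mapping characters to integer indices
    """
    strokes_list, indices_list, labels_list = [], [], []
    for strokes, text in samples:
        strokes = np.asarray(strokes, dtype=np.float32)
        if strokes.ndim != 2 or strokes.shape[1] != 3:
            raise ValueError('strokes must have shape [N, 3]')
        strokes_list.append(strokes)
        # Unknown characters fall back to index 0, as translation.get(c, 0)
        # does in generate.py.
        indices_list.append([translation.get(c, 0) for c in text])
        labels_list.append(text)
    return [strokes_list, indices_list, labels_list]


def save_styles(styles, path):
    """Write the three lists to disk with pickle."""
    with open(path, 'wb') as f:
        pickle.dump(styles, f)


if __name__ == '__main__':
    toy_translation = {'h': 1, 'i': 2}  # hypothetical mapping
    toy_samples = [([[0.0, 0.0, 0.0], [0.5, 0.0, 1.0]], 'hi')]
    styles = build_styles(toy_samples, toy_translation)
    save_styles(styles, os.path.join(tempfile.mkdtemp(), 'styles_example.pkl'))
```

In practice the samples would be lines picked from the preprocessed dataset rather than toy values, and the translation dictionary would be the one used for training.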

As for the format of the styles, you can load the styles.pkl file as follows:
import pickle
with open('styles.pkl', 'rb') as f:
    data = pickle.load(f)

data has 3 elements, each holding 8 arrays because there are currently 8 styles, so x below is between 0 and 7:

  • data[0][x] - stores a numpy array of shape [N, 3], where N is the sequence length. Each row holds 3 numbers, [delta_x, delta_y, end_of_stroke]. (delta_x, delta_y) are the differences in pen coordinates (e.g. (0.5, 0) means the pen moves 0.5 to the right). If end_of_stroke is set to 1, the pen is lifted after drawing the line to that point, so the next line won't be shown.
  • data[1][x] - is a standard Python list of length N. It holds the indices of the letters that make up the text sequence. These should be obtained using the translation mapping from translate.pkl (which stores a Python dictionary that maps letters to indices).
  • data[2][x] - another standard Python list that holds just the text labels.
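
The layout above can be inspected with a small helper like the one below. This is a minimal sketch: `load_style` and `to_absolute` are hypothetical names, and the file is assumed to follow the three-element layout just described. The range check uses `0 <= index < len(...)`; the check in this commit (`args.style > len(styles[0])`) lets an index equal to the number of styles through, which would then fail with an IndexError.

```python
import pickle

import numpy as np


def load_style(path, index):
    """Return (strokes, char_indices, label) for one style."""
    with open(path, 'rb') as f:
        data = pickle.load(f)
    strokes_all, indices_all, labels_all = data[0], data[1], data[2]
    if not 0 <= index < len(strokes_all):
        raise ValueError('Requested style is not in style list')
    return np.asarray(strokes_all[index]), list(indices_all[index]), labels_all[index]


def to_absolute(strokes):
    """Turn [delta_x, delta_y, end_of_stroke] rows into absolute pen positions."""
    xy = np.cumsum(strokes[:, :2], axis=0)
    # Keep the end_of_stroke column unchanged.
    return np.concatenate([xy, strokes[:, 2:]], axis=1)
```

Usage would be along the lines of `strokes, indices, label = load_style('styles.pkl', 0)` followed by `to_absolute(strokes)` to get points that can be plotted. As with any pickle file, only load styles.pkl from a source you trust.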
