In [1]:
from PIL import ImageFont, ImageDraw, Image
from fontTools.ttLib import TTFont

import numpy as np
import tensorflow as tf

In [2]:
# Model / vocabulary hyper-parameters. The networks below are built from these,
# so presumably they must match the values used when the checkpoint was trained.
VOCAB = 28        # decoder vocabulary size (Embedding input dim and logits width)
MAX_LEN = 5       # classes of the length head and number of decode steps
MAX_DUP = 6       # classes of the duplicate-count head
EBD_DIM = 256     # encoder feature / decoder embedding width
UNIT_DIM = 128    # GRU hidden size
BATCH_SIZE = 256  # not referenced by the cells shown here

## Load Font

In [3]:
class Glyph(object):
    """Render single characters to square grayscale bitmaps.

    Several font files can be supplied; for each character the first font whose
    Unicode cmap contains its code point is used. Rendered bitmaps are cached
    per character.
    """

    def __init__(self, fonts, size=64):
        """Load the fonts and record which code points each one covers.

        Args:
            fonts: iterable of font file paths, in lookup priority order.
            size: side length in pixels of the output bitmap. Glyphs are drawn
                with font size int(0.8 * size) and offset by a margin of
                (size - int(0.8 * size)) // 2.
        """
        # materialise once: the paths are iterated twice below
        fonts = list(fonts)
        self.size = int(size * 0.8)
        self.size_img = size
        self.pad = (size - self.size) // 2
        self.fonts = [ImageFont.truetype(f, self.size) for f in fonts]
        # already-rendered characters -> bitmap; avoids re-drawing duplicates
        self.cache = {}
        # one set of supported code points per font, aligned with self.fonts
        self.codepoints = []
        for path in fonts:
            # TTFont keeps its source file open until closed; the context
            # manager releases it as soon as the cmap has been read
            with TTFont(path) as tt:
                supported = set()
                for table in tt['cmap'].tables:
                    if table.isUnicode():
                        supported.update(table.cmap)
            self.codepoints.append(supported)

    def draw(self, ch):
        """Return ``ch`` rendered as a (size_img, size_img) float32 array in [0, 1].

        Returns None when no loaded font covers ``ch``; that result is not cached.
        """
        if ch in self.cache:
            return self.cache[ch]
        # first font (in priority order) that has a glyph for this code point
        for supported, font in zip(self.codepoints, self.fonts):
            if ord(ch) in supported:
                break
        else:
            return None

        img = Image.new('L', (self.size_img, self.size_img), 0)
        draw = ImageDraw.Draw(img)
        # NOTE(review): `font.font` is Pillow's internal core font object; its
        # getsize presumably returns (size, offset) with offset being where the
        # ink starts relative to the drawing origin. Subtracting the offset
        # places the glyph at the padding margin.
        _, (offset_x, offset_y) = font.font.getsize(ch)
        # the extra +4 px vertical shift is a hard-coded adjustment kept as-is
        draw.text((self.pad - offset_x, self.pad - offset_y + 4), ch, font=font, fill=255, stroke_fill=255)
        # 'L' image -> (H, W) array, scaled from 0..255 to 0..1
        img_array = np.asarray(img, dtype='float32') / 255
        self.cache[ch] = img_array
        return img_array

In [4]:
# Character renderer: HanaMinA is consulted first; HanaMinB is only used for
# code points HanaMinA does not cover.
glyphbook = Glyph(
    fonts=['data/fonts/HanaMinA.otf', 'data/fonts/HanaMinB.otf'],
    size=64,
)

## Model

In [5]:
class Res_CNN(tf.keras.Model):
    """Three same-padded convolutions with two additive skips, then BatchNorm + ReLU.

    With c1, c2, c3 the three convolutions:
        a = c1(x); b = c2(a); out = relu(BN(c3(a + b) + b))
    """

    def __init__(self, feature_dim, kernel_size):
        super(Res_CNN, self).__init__()

        def make_conv():
            return tf.keras.layers.Convolution2D(feature_dim, kernel_size, padding='same')

        # attribute names double as checkpoint keys -- keep them stable
        self.cnn1 = make_conv()
        self.cnn2 = make_conv()
        self.cnn3 = make_conv()
        self.norm = tf.keras.layers.BatchNormalization()

    def call(self, x):
        first = self.cnn1(x)
        second = self.cnn2(first)
        # skip 1: the third conv sees the sum of the first two conv outputs
        merged = self.cnn3(first + second)
        # skip 2: add the second conv output back before normalising
        return tf.nn.relu(self.norm(merged + second))

In [6]:
class CNN_Encoder(tf.keras.Model):
    """Image encoder: residual conv stages -> sequence of projected feature vectors.

    Channel widths are embedding_dim // 16, embedding_dim // 4 and embedding_dim.
    Two 2x2 max-pools shrink each spatial side by 4, so a 64x64 input becomes a
    16x16 grid that is flattened to (batch_size, 256, embedding_dim).
    """

    def __init__(self, embedding_dim):
        super(CNN_Encoder, self).__init__()
        kernel = (3, 3)
        window = (2, 2)
        # attribute names double as checkpoint keys -- keep them stable
        self.res_cnn1 = Res_CNN(embedding_dim // 16, kernel)
        self.pool1 = tf.keras.layers.MaxPool2D(window)
        self.res_cnn2 = Res_CNN(embedding_dim // 4, kernel)
        self.pool2 = tf.keras.layers.MaxPool2D(window)
        self.res_cnn3 = Res_CNN(embedding_dim, kernel)
        self.fc = tf.keras.layers.Dense(embedding_dim, activation='relu')

    def call(self, x, training=True):
        # stages 1 and 2: residual block followed by 2x2 max pooling
        #   64x64 -> 32x32 (embedding_dim // 16) -> 16x16 (embedding_dim // 4)
        for res_block, pool in ((self.res_cnn1, self.pool1), (self.res_cnn2, self.pool2)):
            x = pool(res_block(x))
        # stage 3: residual block at full embedding width, no pooling
        x = self.res_cnn3(x)
        # flatten the grid into a sequence: (batch, H, W, C) -> (batch, H * W, C)
        x = tf.reshape(x, [x.shape[0], -1, x.shape[-1]])
        if training:
            x = tf.nn.dropout(x, rate=0.4)
        # per-position projection, shape unchanged: (batch, H * W, embedding_dim)
        return self.fc(x)

In [7]:
class Bahdanau_Attention(tf.keras.Model):
    """Additive (Bahdanau) attention: score = V(tanh(W1(features) + W2(query)))."""

    def __init__(self, attention_dim):
        super(Bahdanau_Attention, self).__init__()
        # attribute names double as checkpoint keys -- keep them stable
        self.W1 = tf.keras.layers.Dense(attention_dim)
        self.W2 = tf.keras.layers.Dense(attention_dim)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        """Attend over `features` using `hidden` as the query.

        Args:
            features: (batch_size, positions, feature_dim) encoder output.
            hidden: (batch_size, query_dim) query vector.
        Returns:
            context_vector: (batch_size, feature_dim) weighted sum of features.
            attention_weights: (batch_size, positions, 1), softmax over positions.
        """
        # broadcast the query over all positions: (batch_size, 1, query_dim)
        query = tf.expand_dims(hidden, 1)
        energies = self.V(tf.nn.tanh(self.W1(features) + self.W2(query)))
        attention_weights = tf.nn.softmax(energies, axis=1)
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)
        return context_vector, attention_weights

In [8]:
class Length_Decoder(tf.keras.Model):
    """Classification head over encoder features producing `max_length` logits.

    In this notebook it is instantiated twice: as `length_decoder`
    (max_length=MAX_LEN) and as `dup_decoder` (max_length=MAX_DUP).
    """

    def __init__(self, max_length):
        super(Length_Decoder, self).__init__()
        # attribute names double as checkpoint keys -- keep them stable
        self.pool = tf.keras.layers.MaxPool2D((2, 2))
        self.fc1 = tf.keras.layers.Dense(max_length * 16, activation='relu')
        self.fc2 = tf.keras.layers.Dense(max_length * 16, activation='relu')
        self.fc3 = tf.keras.layers.Dense(max_length * 4, activation='relu')
        self.fc4 = tf.keras.layers.Dense(max_length)

    def call(self, x, d_t=None, d_c=None):
        """Compute class logits.

        Args:
            x: encoder features; reshaped to (batch_size, 16, 16, channels), so
                the middle axis must hold 256 positions.
            d_t, d_c: optional conditioning tensors. Only when BOTH are given
                are they cast to float32, concatenated and prepended to the
                flattened features before fc2; a single one is ignored.
        Returns:
            (batch_size, max_length) logits (fc4 has no activation).
        """
        x = tf.reshape(x, (x.shape[0], 16, 16, x.shape[-1]))
        x = self.pool(x)  # (batch_size, 8, 8, channels)
        x = self.fc1(x)   # Dense on last axis: (batch_size, 8, 8, max_length * 16)
        x = tf.reshape(x, (x.shape[0], -1))
        # `is not None` rather than `!= None`: on a NumPy array `!=` is
        # element-wise and the following `and` would raise ValueError
        if d_t is not None and d_c is not None:
            d = tf.concat([tf.cast(d_t, 'float32'), tf.cast(d_c, 'float32')], axis=-1)
            x = tf.concat([d, x], axis=-1)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        # shape = (batch_size, max_length)
        return x

In [9]:
class RNN_Decoder(tf.keras.Model):
    """One step of the autoregressive decoder.

    Combines additive attention over the encoder features with three stacked
    GRUs (joined by additive skip connections) and two Dense layers that produce
    vocabulary logits. The length vector `l` is part of the attention query, and
    `l` plus the conditioning `d_t`/`d_c` are concatenated beside the GRU outputs.
    """

    def __init__(self, embedding_dim, hidden_size, vocab_size, max_length):
        # NOTE: `max_length` is accepted but not used anywhere in this class.
        super(RNN_Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # gru1 and gru3 apply input dropout 0.3 (only when training=True); gru2 has none
        self.gru1 = tf.keras.layers.GRU(self.hidden_size, return_sequences=True,
                                        return_state=True, recurrent_initializer='glorot_uniform', dropout=0.3)
        self.gru2 = tf.keras.layers.GRU(self.hidden_size, return_sequences=True,
                                        return_state=True, recurrent_initializer='glorot_uniform')
        self.gru3 = tf.keras.layers.GRU(self.hidden_size, return_sequences=True,
                                        return_state=True, recurrent_initializer='glorot_uniform', dropout=0.3)
        self.fc1 = tf.keras.layers.Dense(hidden_size, activation='relu')
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.attention = Bahdanau_Attention(hidden_size)

    def call(self, x, l, d_t, d_c, features, hidden, training=True, teacher_forcing=True):
        """Advance the decoder by one step.

        Args:
            x: if `teacher_forcing`, token indices fed to the Embedding layer
                (presumably (batch_size, 1)); otherwise a (batch_size, 1, vocab_size)
                probability vector contracted with the embedding matrix
                (a soft embedding lookup).
            l: length vector, rank 2 (batch_size, len_dim); cast to float32.
            d_t, d_c: rank-2 conditioning tensors, batch first; cast to float32
                and concatenated on the last axis.
            features: encoder output, (batch_size, positions, feature_dim).
            hidden: list of three GRU states, each (batch_size, hidden_size),
                e.g. from `reset_state`.
            training: forwarded to the GRU layers (enables their dropout).
            teacher_forcing: selects how `x` is embedded (see above).

        Returns:
            logits: (batch_size, vocab_size), no softmax applied.
            states: [state1, state2, state3], the new GRU states.
            attention_weights: (batch_size, positions, 1).
        """
        # attention query = [length vector, previous state of gru1]
        l = tf.cast(l, 'float32')
        hidden_0_with_length = tf.concat([l, hidden[0]], axis=-1)
        context_vector, attention_weights = self.attention(features, hidden_0_with_length)
        # add a time axis so l and d can be concatenated with (batch_size, 1, ...) tensors
        l = tf.expand_dims(l, 1)
        d = tf.expand_dims(tf.concat([tf.cast(d_t, 'float32'), tf.cast(d_c, 'float32')], axis=-1), 1)

        # embed x -> (batch_size, 1, embedding_dim)
        if teacher_forcing:
            x = self.embedding(x)
        else:
            # the embedding matrix is read directly below, so make sure it exists
            if not self.embedding.built:
                self.embedding.build(x.shape)
            x = tf.tensordot(x, self.embedding.weights[0], axes=[-1,0])
        # context_vector is (batch_size, feature_dim); after concatenation
        # x is (batch_size, 1, feature_dim + embedding_dim)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # each GRU returns a (batch_size, 1, hidden_size) output sequence and its
        # new (batch_size, hidden_size) state, used as `hidden` on the next step
        x, state1 = self.gru1(x, initial_state = hidden[0], training=training)
        x_identity = tf.identity(x)
        x = tf.concat([d, l, x], axis=-1)
        x, state2 = self.gru2(x, initial_state = hidden[1], training=training)
        x_identity2 = tf.identity(x)
        # skip connection: gru3 input = gru2 output + gru1 output
        x, state3 = self.gru3(x + x_identity, initial_state = hidden[2], training=training)
        # skip connection: gru3 output + gru2 output, with d and l prepended
        # x shape (batch_size, 1, d_dim + len_dim + hidden_size)
        x = tf.concat([d, l, x + x_identity2], axis=-1)
        # drop the length-1 time axis: (batch_size, d_dim + len_dim + hidden_size)
        x = tf.reshape(x, (x.shape[0], -1))
        # x shape (batch_size, hidden_size)
        x = self.fc1(x)
        # x shape (batch_size, vocab_size): unnormalised logits
        x = self.fc2(x)

        return x, [state1, state2, state3], attention_weights

    def reset_state(self, batch_size):
        # zero initial states for the three GRUs, each (batch_size, hidden_size)
        return [tf.zeros((batch_size, self.hidden_size)) for _ in range(3)]

## Load Model

In [10]:
# One Adam optimizer per trainable part; in this notebook they are only
# referenced by the checkpoint object so its saved structure can be matched.
optimizer, optimizer_length, optimizer_dups = (tf.keras.optimizers.Adam() for _ in range(3))

In [11]:
# Build the networks. Keras creates their variables lazily on first call;
# TF checkpoint restoration fills them in when they are created.
encoder = CNN_Encoder(embedding_dim=EBD_DIM)
length_decoder = Length_Decoder(max_length=MAX_LEN)  # length head, MAX_LEN classes
dup_decoder = Length_Decoder(max_length=MAX_DUP)     # duplicate-count head, MAX_DUP classes
decoder = RNN_Decoder(
    embedding_dim=EBD_DIM,
    hidden_size=UNIT_DIM,
    vocab_size=VOCAB,
    max_length=MAX_LEN,
)

In [12]:
# use a checkpoint to store weights
checkpoint_path = "./checkpoints/train_step2"
ckpt = tf.train.Checkpoint(encoder=encoder, decoder=decoder, length_decoder=length_decoder,
    dup_decoder=dup_decoder, optimizer=optimizer, optimizer_length=optimizer_length, optimizer_dups=optimizer_dups)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
ckpt.restore(ckpt_manager.latest_checkpoint)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x14c374c90>

## Testing

In [13]:
def predict(features, max_length, length, total_dups, curr_dups):
    """Decode `max_length` steps, feeding each step's softmax back as the next input.

    Runs the global `decoder` without teacher forcing. The first input is a
    one-hot vector for index 0. Every later input is the full probability
    vector from the previous step, not its argmax.

    Returns:
        indexes: int64 (batch_size, max_length + 1), the argmax of every step's
            distribution, including the leading index-0 start step.
        probability: float32 (batch_size,), the product over steps of each
            step's maximum probability.
    """
    batch_size = features.shape[0]
    start = [1] + [0] * (VOCAB - 1)
    # per-step distributions, each (batch_size, VOCAB); starts with the one-hot start row
    step_dists = [tf.constant([start] * batch_size, dtype='float32')]
    hidden = decoder.reset_state(batch_size=batch_size)
    probability = tf.convert_to_tensor([1] * batch_size, dtype='float32')
    for _step in range(max_length):
        logits, hidden, _attention = decoder(
            tf.expand_dims(step_dists[-1], 1), length, total_dups, curr_dups, features, hidden,
            training=False, teacher_forcing=False)
        step_probs = tf.math.softmax(logits, axis=-1)
        probability *= tf.math.reduce_max(step_probs, axis=-1)
        step_dists.append(step_probs)
    all_steps = tf.concat([tf.expand_dims(dist, 1) for dist in step_dists], axis=1)
    return tf.math.argmax(all_steps, axis=-1), probability

In [14]:
@tf.function
def test(glyph):
    """Run the full inference pipeline on a batch of rendered glyphs.

    Args:
        glyph: float32 batch of glyph images; `evaluate` builds it as
            (n_chars, 64, 64, 1).
    Returns:
        results: int64 (batch, max_dup, MAX_LEN + 1) index sequences from
            `predict`, one row per duplicate slot. max_dup is the largest
            predicted count index in the batch + 1. Slots past a sample's own
            count repeat its last slot, and callers filter them with `dups_dict`.
        probs: float32 (batch, max_dup) sequence probabilities from `predict`.
        dups_dict: int64 (batch,) argmax class of the duplicate-count head.
    """
    features = encoder(glyph, training=False)
    # duplicate-count head: softmax distribution and its argmax per sample
    total_dups = tf.nn.softmax(dup_decoder(features), axis=-1)
    dups_dict = tf.math.argmax(total_dups, axis=-1)
    max_dup = tf.math.reduce_max(dups_dict) + 1

    results = tf.zeros((glyph.shape[0], max_dup, MAX_LEN + 1), dtype='int64')
    probs = tf.zeros((glyph.shape[0], max_dup), dtype='float32')
    identity_matrix = tf.convert_to_tensor(np.identity(MAX_DUP), dtype='int64')

    # max_dup is a tensor, so AutoGraph converts this loop into a tf.while_loop
    for i in range(max_dup):
        # slot index clipped to each sample's predicted count; reuses the argmax
        # computed above instead of recomputing it every iteration
        curr_dups = tf.math.minimum(dups_dict, i)
        curr_dups = tf.nn.embedding_lookup(identity_matrix, curr_dups)  # one-hot rows
        length = tf.nn.softmax(length_decoder(features, total_dups, curr_dups), axis=-1)
        test_result, prob = predict(features, MAX_LEN, length, total_dups, curr_dups)
        # write slot i by rebuilding the tensors (tensors are immutable)
        results = tf.concat([results[:, :i, :], tf.expand_dims(test_result, axis=1), tf.zeros((glyph.shape[0], max_dup - i - 1, MAX_LEN + 1), dtype='int64')], axis=1)
        probs = tf.concat([probs[:, :i], tf.expand_dims(prob, axis=1), tf.zeros((glyph.shape[0], max_dup - i - 1), dtype='float32')], axis=1)
    return results, probs, dups_dict

In [15]:
def evaluate(word):
    """Predict candidate codes for every character of `word`.

    Args:
        word: string whose characters must all be covered by `glyphbook`'s fonts.
    Returns:
        One list per character. Each holds `[code, probability]` pairs for
        duplicate slots 0..dups_dict[i], where `code` is a lowercase string.
        An empty `word` yields [].
    Raises:
        ValueError: if a character has no glyph in the loaded fonts.
    """
    if not word:
        # an empty list would not form the 4-D image batch `test` expects
        return []

    glyphs = []
    for char in word:
        glyph = glyphbook.draw(char)
        if glyph is None:
            raise ValueError('Character {} unsupported.'.format(char))
        glyphs.append(glyph)
    # (n, size, size) -> (n, size, size, 1): add the single channel axis
    test_input = np.expand_dims(glyphs, -1)

    # index that terminates a code (equals VOCAB - 1 with VOCAB == 28)
    end_index = 27

    def decode(indexes):
        """Map indices 1..26 to 'a'..'z', skipping 0 and stopping at end_index."""
        code = ''
        for i in indexes:
            if i <= 0:
                continue  # index 0: the start step fed in by `predict`
            if i >= end_index:
                break
            code += chr(i + 96)
        return code

    results, probs, dups_dict = test(test_input)
    results, probs, dups_dict = results.numpy(), probs.numpy(), dups_dict.numpy()

    # keep only the slots up to each character's own predicted duplicate count
    return [
        [[decode(results[i, j, :]), probs[i, j]]
         for j in range(results.shape[1]) if j <= dups_dict[i]]
        for i in range(results.shape[0])
    ]

In [16]:
# Run the pipeline on a sample string; the output lists [code, probability]
# candidates for each character, in input order.
evaluate(word='日月金木水火土的戈十大中一弓人心手口尸廿山女田止卜片')

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


[[['a', 0.78253275]],
 [['b', 0.9760856]],
 [['c', 0.8359161]],
 [['d', 0.964399]],
 [['e', 0.9135183]],
 [['f', 0.5961462]],
 [['g', 0.56496906]],
 [['h', 0.9249467]],
 [['i', 0.9969825]],
 [['j', 0.9766994]],
 [['k', 0.87347305]],
 [['l', 0.9695252]],
 [['m', 0.88853776]],
 [['n', 0.94106185]],
 [['o', 0.53135574]],
 [['i', 0.33893174]],
 [['q', 0.9779192]],
 [['au', 0.44066703]],
 [['s', 0.8695538]],
 [['t', 0.9750777]],
 [['u', 0.7820365]],
 [['v', 0.72862536]],
 [['w', 0.9937593]],
 [['x', 0.93955576]],
 [['y', 0.92141175]],
 [['llml', 0.32722926], ['llms', 0.84263015]]]