In [1]:
from PIL import ImageFont, ImageDraw, Image
from fontTools.ttLib import TTFont
import tensorflow as tf
import numpy as np

## Glyph Renderer and Predictor Classes

In [2]:
class Glyph(object):
    """Render single characters to square grayscale bitmaps using a list of fonts.

    Fonts are tried in the order given; the first one whose Unicode cmap contains
    the character's code point is used. Results of ``draw`` (including misses,
    stored as None) are memoized in ``self.cache``.
    """

    # Extra downward shift (in pixels) applied when drawing. Empirical value kept
    # from the original implementation -- TODO confirm its purpose.
    Y_NUDGE = 4

    def __init__(self, fonts, size=64):
        """
        Args:
            fonts: paths of font files; each is opened by both Pillow and fontTools.
            size: side length in pixels of the square output bitmap.
        """
        fonts = list(fonts)
        # Glyphs are drawn at 90% of the canvas and centred with equal padding.
        self.size = int(size * 0.9)
        self.size_img = size
        self.pad = (size - self.size) // 2
        self.fonts = [ImageFont.truetype(path, self.size) for path in fonts]
        # One set of supported code points per font, parallel to self.fonts.
        self.codepoints = [self._read_codepoints(path) for path in fonts]
        # Memoizes draw() results so repeated characters are rendered only once.
        self.cache = {}

    @staticmethod
    def _read_codepoints(path):
        """Return the code points mapped by any Unicode cmap subtable of the font at ``path``."""
        tt = TTFont(path)
        try:
            codepoints = set()
            for table in tt['cmap'].tables:
                if table.isUnicode():
                    codepoints.update(table.cmap)
            return codepoints
        finally:
            # TTFont keeps its reader (and underlying file) open until close().
            tt.close()

    def draw(self, ch):
        """Return the bitmap of ``ch`` as a (size_img, size_img) uint8 array, or None.

        None is returned when no font maps ``ch``. ``ch`` must be a single character
        (``ord`` is applied to it). The returned array is the cached object, so
        callers should copy it before mutating.
        """
        if ch in self.cache:
            return self.cache[ch]

        code = ord(ch)
        for font, supported in zip(self.fonts, self.codepoints):
            if code in supported:
                break
        else:
            # No font covers this character; remember the miss.
            self.cache[ch] = None
            return None

        img = Image.new('L', (self.size_img, self.size_img), 0)
        # font.font is Pillow's core font object; only the offset pair is needed
        # to align the glyph's ink box with the padding.
        _, (offset_x, offset_y) = font.font.getsize(ch)
        ImageDraw.Draw(img).text(
            (self.pad - offset_x, self.pad - offset_y + self.Y_NUDGE),
            ch, font=font, fill=255, stroke_fill=255)
        img_array = np.array(img, dtype='uint8')
        self.cache[ch] = img_array
        return img_array

In [3]:
class Predictor:
    """Predict Cangjie codes for CJK characters with a TFLite model.

    Each character is rendered to a bitmap by ``Glyph``. The model is then run
    autoregressively: at every step it receives the tokens generated so far plus
    the bitmaps, and the argmax token at the last position is appended.
    """

    # Token ids as used by encode()/decode().
    START_TOKEN = 1      # first token of every generated sequence
    END_TOKEN = 2        # terminates a sequence
    SEP_TOKEN = 3        # emitted as ',' in the decoded string
    LETTER_OFFSET = 93   # token ids > SEP_TOKEN map to chr(id + 93), i.e. 4 -> 'a'

    DEFAULT_FONTS = ('data/fonts/TH-Tshyn-P0.ttf',
                     'data/fonts/TH-Tshyn-P1.ttf',
                     'data/fonts/TH-Tshyn-P2.ttf')

    def __init__(self, interpretor, img_size=128, fonts=None):
        """
        Args:
            interpretor: a TFLite interpreter. Input 0 receives the token codes,
                input 1 the glyph bitmaps; output 0 holds per-position scores.
            img_size: side length in pixels of the rendered glyph bitmaps.
            fonts: font file paths tried in order; defaults to DEFAULT_FONTS.
        """
        font_paths = list(self.DEFAULT_FONTS if fonts is None else fonts)
        self.glyph_generator = Glyph(font_paths, size=img_size)
        self.interpretor = interpretor
        self.input = interpretor.get_input_details()
        self.output = interpretor.get_output_details()

    def generate_glyphs(self, chars):
        """Render every character in ``chars``; return a uint8 array stacked on axis 0.

        Prints an error and returns None as soon as a character cannot be drawn.
        """
        glyphs = []
        for char in chars:
            glyph = self.glyph_generator.draw(char)
            if glyph is None:
                print(f'Error: cannot draw {char}')
                return None
            glyphs.append(glyph)
        return np.array(glyphs, dtype='uint8')

    def encode(self, glyphs, max_len=32):
        """Greedily generate token sequences for a batch of glyph bitmaps.

        Args:
            glyphs: uint8 array whose first axis is the batch (one bitmap per char).
            max_len: maximum number of generation steps.

        Returns:
            uint8 array of shape (batch, steps + 1). Column 0 is START_TOKEN.
            Generation stops early once every row contains END_TOKEN.
        """
        codes_index = self.input[0]['index']
        glyphs_index = self.input[1]['index']
        output_index = self.output[0]['index']
        codes = np.full((glyphs.shape[0], 1), self.START_TOKEN, dtype='uint8')
        for _ in range(max_len):
            # The code sequence grows by one each step, so the inputs are resized,
            # re-allocated and re-filled on every iteration.
            self.interpretor.resize_tensor_input(codes_index, codes.shape)
            self.interpretor.resize_tensor_input(glyphs_index, glyphs.shape)
            self.interpretor.allocate_tensors()
            self.interpretor.set_tensor(codes_index, codes)
            self.interpretor.set_tensor(glyphs_index, glyphs)
            self.interpretor.invoke()
            scores = self.interpretor.get_tensor(output_index)
            # Scores are indexed (batch, position, vocab); only the last position
            # yields the new token.
            next_tokens = np.argmax(scores, axis=-1).astype('uint8')[:, -1:]
            codes = np.concatenate([codes, next_tokens], axis=-1)
            if (codes == self.END_TOKEN).any(axis=-1).all():
                break
        return codes

    def decode(self, codes):
        """Convert each row of token ids into a Cangjie string.

        Tokens after the first END_TOKEN are ignored. SEP_TOKEN becomes ','.
        Ids above SEP_TOKEN become letters. Other ids (0, START_TOKEN) are skipped.
        """
        cangjie_codes = []
        for row in codes:
            chars = []
            for token in row:
                # int() avoids uint8 wrap-around in the offset addition.
                token = int(token)
                if token == self.END_TOKEN:
                    break
                if token == self.SEP_TOKEN:
                    chars.append(',')
                elif token > self.SEP_TOKEN:
                    chars.append(chr(token + self.LETTER_OFFSET))
            cangjie_codes.append(''.join(chars))
        return cangjie_codes

    def __call__(self, chars, max_len=32):
        """Return a list of Cangjie code strings, one per character in ``chars``.

        Returns None if any character cannot be rendered, and [] for empty input.
        """
        glyphs = self.generate_glyphs(chars)
        if glyphs is None:
            return None
        if glyphs.shape[0] == 0:
            return []
        codes = self.encode(glyphs, max_len=max_len)
        return self.decode(codes)

## Load Model and Predict

In [4]:
# TFLite model file, resolved relative to the notebook's working directory.
MODEL_PATH = 'cangjie_pruned.tflite'
cangjie = tf.lite.Interpreter(model_path=MODEL_PATH)

Metal device set to: Apple M1 Max


INFO: Created TensorFlow Lite delegate for select TF ops.
2023-03-30 19:59:36.657699: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-03-30 19:59:36.657848: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
INFO: TfLiteFlexDelegate delegate: 10 nodes delegated out of 1277 nodes with 6 partitions.



In [5]:
# Wrap the loaded interpreter using Predictor's default fonts and img_size.
pred = Predictor(interpretor=cangjie)

In [6]:
# Predict one Cangjie code string per character of the phrase.
pred(chars='天下爲公')

['mk', 'my', 'bhnf', 'ci']