In [1]:
import tensorflow as tf
import numpy as np
import string
import pathlib
import os

from tensorflow.keras.models import Model
from tensorflow.keras.layers import GRU, Embedding, Input
from tensorflow.keras.layers.experimental.preprocessing import StringLookup
from tensorflow.keras.applications import EfficientNetB1
from tensorflow.io.gfile import GFile
from tensorflow.strings import unicode_split
from tensorflow.data import Dataset, TextLineDataset

In [None]:
# Image encoder: EfficientNetB1 backbone without its classification head
# (include_top=False), randomly initialised (weights=None), for 100x100 RGB
# inputs.
# NOTE(review): Keras EfficientNet models are documented to rescale pixels
# internally and to expect values in [0, 255], while
# DataManager.parse_svg_img divides images by 255 -- confirm the intended
# input range before feeding those images to this model.
efficient_net = EfficientNetB1(
    include_top=False,
    weights=None,
    input_shape=(100, 100, 3),
)

In [None]:
# Hyper-parameters of the SVG token embedding model.
# NOTE(review): DataManager builds its vocabulary from string.printable
# (100 characters) plus the two marker tokens, and StringLookup reserves
# additional ids on top of that, so the ids it emits can exceed 84. The
# padded batches printed at the end of the notebook are also 132 tokens long,
# not 100. Confirm `vocab_size` and `sequence_length` before training on it.
vocab_size = 84
embedding_size = 256
sequence_length = 100

# The shape must be a tuple: `(sequence_length)` is only a parenthesised int,
# so the trailing comma is required to describe a 1-D sequence of token ids.
inputs = Input(shape=(sequence_length,))

# Map every token id to a dense vector of size `embedding_size`.
outs = Embedding(input_dim=vocab_size, output_dim=embedding_size, input_length=sequence_length)(inputs)

# TODO: outs = GRU()(inputs)

svg_net = Model(inputs=inputs, outputs=outs)

svg_net.summary()

In [2]:
class DataManager:
    """Builds a tf.data pipeline of (SVG token ids, image) pairs.

    The dataset directory ``log_dir`` is read as follows:
      * ``file_names.txt``  -- one sample name per line;
      * ``svgs/<name>.svg`` -- SVG source of the sample;
      * ``imgs/<name>.png`` -- PNG image of the sample.
    """

    def __init__(self, log_dir):
        """Store the dataset location and build the character vocabulary.

        Args:
            log_dir: Directory containing ``file_names.txt``, ``svgs/`` and
                ``imgs/`` (``str`` or ``pathlib.Path``).
        """
        self.log_dir = log_dir
        self.START_TOKEN = '[SOS]'
        self.END_TOKEN = '[EOS]'
        # Sorting fixes the character order, so the char -> id mapping is the
        # same on every run (iteration order of a set is not guaranteed).
        self.vocab = sorted(set(string.printable)) + [self.START_TOKEN, self.END_TOKEN]
        # padded_batch pads the id sequences with 0. Pinning mask_token=''
        # keeps id 0 reserved for padding on every TF version (newer releases
        # default to mask_token=None, which would make id 0 the OOV token).
        self.chars_to_ids = StringLookup(vocabulary=self.vocab, mask_token='')

    def load_dataset(self, limit=5, batch_size=2):
        """Return a batched dataset of ``(svg_ids, image)`` pairs.

        Args:
            limit: Maximum number of samples read from ``file_names.txt``;
                ``None`` reads every listed sample. Defaults to 5.
            batch_size: Number of samples per batch. Defaults to 2.

        Returns:
            A ``tf.data.Dataset`` whose elements are ``(svg_ids, images)``
            batches. Each batch is padded to its largest element, and a
            trailing partial batch is dropped (``drop_remainder=True``), so
            the defaults yield 2 batches out of 5 samples.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

        ds = TextLineDataset(str(pathlib.Path(self.log_dir, 'file_names.txt')))
        if limit is not None:
            ds = ds.take(limit)
        ds = ds.map(self.parse_svg_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.padded_batch(batch_size, drop_remainder=True)

        return ds

    def parse_svg_img(self, file_name):
        """Load one sample: the SVG as token ids and the PNG as a float image.

        Args:
            file_name: Scalar string tensor with the sample name (no
                directory, no extension).

        Returns:
            ``(svg, img)`` where ``svg`` holds the ids of
            ``[SOS] + characters of the SVG file + [EOS]`` and ``img`` is the
            3-channel image as ``float32`` divided by 255.
        """
        # str() lets log_dir be a pathlib.Path here too, as load_dataset
        # already allows; a plain str is passed through unchanged.
        log_dir = str(self.log_dir)
        svg_path = tf.strings.join([log_dir, '/svgs/', file_name, '.svg'])
        img_path = tf.strings.join([log_dir, '/imgs/', file_name, '.png'])

        # Split the file into single characters and wrap them with the
        # start/end markers before looking up their ids.
        svg = tf.io.read_file(svg_path)
        svg = tf.concat([[self.START_TOKEN], unicode_split(svg, 'UTF-8'), [self.END_TOKEN]], axis=0)
        svg = self.chars_to_ids(svg)

        img = tf.io.read_file(img_path)
        img = tf.io.decode_png(img, channels=3)
        img = tf.cast(img, tf.float32)
        img = img / 255.0

        return svg, img

# Build the pipeline over the local `dataset` directory and sanity-check the
# shapes of the batches it yields.
dm = DataManager('dataset')

ds = dm.load_dataset()

# Unpack each batch into named variables: `_` is IPython's "last output"
# variable and should not be rebound as a loop variable.
for svg_batch, img_batch in ds:
    print(svg_batch.shape, img_batch.shape)

(2, 132) (2, 100, 100, 3)
(2, 132) (2, 100, 100, 3)
