In [1]:
import numpy as np
import cairosvg
import tensorflow as tf
import matplotlib.pyplot as plt
import time
import pathlib

In [2]:
class DatasetGenerator:
    """Generate a dataset of random single-line SVGs and their PNG renderings.

    Each sample is a 100x100 SVG containing one line segment with a white
    stroke. The SVG is written to ``<log_dir>/svgs`` and rendered through
    cairosvg into ``<log_dir>/imgs``; both files share the base name
    ``simple_line-<timestamp>``.
    """

    def __init__(self, log_dir):
        """Store the output directory and create its sub-directories.

        Args:
            log_dir: Root directory of the dataset. The ``svgs`` and ``imgs``
                sub-directories are created beneath it when missing.
        """
        self.log_dir = log_dir

        pathlib.Path(log_dir, 'svgs').mkdir(parents=True, exist_ok=True)
        pathlib.Path(log_dir, 'imgs').mkdir(parents=True, exist_ok=True)

    def generate_lines(self, count, seed=None):
        """Create ``count`` random line samples and an index of their names.

        For every sample an SVG and a PNG are written, and the base names of
        all samples produced by this call are saved, one per line, to
        ``<log_dir>/file_names.csv``. That file is overwritten on each call.

        Args:
            count: Number of samples to generate.
            seed: Optional seed for the coordinate generator. When ``None``
                (the default) the global NumPy random state is used, so the
                coordinates differ from run to run.
        """
        # Endpoints are integers in [0, 100]; shape is (count, point, axis).
        if seed is None:
            all_coords = np.random.randint(0, 101, size=(count, 2, 2))
        else:
            all_coords = np.random.RandomState(seed).randint(0, 101, size=(count, 2, 2))

        file_names = []
        issued = set()  # names handed out during this call

        for (start_x, start_y), (end_x, end_y) in all_coords:
            # Build the svg
            svg = self.build_svg(start_x, start_y, end_x, end_y)

            # The clock alone is not unique: consecutive iterations can read
            # the same value, which used to overwrite earlier samples.
            timestamp = self._reserve_timestamp(issued)

            # Write the svg and image to disk
            self.save_svg(svg, timestamp)
            self.save_img(svg, timestamp)

            file_names.append(f'simple_line-{timestamp}')

        np.savetxt(pathlib.Path(self.log_dir, 'file_names.csv'), file_names, fmt='%s', delimiter=',')

    def _reserve_timestamp(self, issued):
        """Return a timestamp string that no existing sample uses yet.

        The base value is the current ``time.time()`` reading with the decimal
        point removed. If that value was already issued in this call, or a file
        with that name already exists on disk, ``_1``, ``_2``, ... is appended
        until the name is free. The result is recorded in ``issued``.
        """
        base = str(time.time()).replace('.', '')
        candidate = base
        attempt = 0
        while candidate in issued or self._is_on_disk(candidate):
            attempt += 1
            candidate = f'{base}_{attempt}'
        issued.add(candidate)
        return candidate

    def _is_on_disk(self, timestamp):
        """Tell whether an SVG or PNG already exists for ``timestamp``."""
        return self._svg_path(timestamp).exists() or self._img_path(timestamp).exists()

    def _svg_path(self, timestamp):
        """Return the path of the SVG file belonging to ``timestamp``."""
        return pathlib.Path(self.log_dir, 'svgs', f'simple_line-{timestamp}.svg')

    def _img_path(self, timestamp):
        """Return the path of the PNG file belonging to ``timestamp``."""
        return pathlib.Path(self.log_dir, 'imgs', f'simple_line-{timestamp}.png')

    def build_svg(self, start_x, start_y, end_x, end_y):
        """Return the SVG markup for one line segment on a 100x100 canvas.

        Args:
            start_x, start_y: Coordinates of the first endpoint.
            end_x, end_y: Coordinates of the second endpoint.

        Returns:
            The SVG document as a string: a single ``<path>`` with no fill and
            a ``#ffffff`` stroke running from the first to the second endpoint.

        Raises:
            AssertionError: If any coordinate lies outside ``[0, 100]``.
        """
        for value in (start_x, start_y, end_x, end_y):
            assert 0 <= value <= 100, f'coordinate {value} is outside the 0-100 canvas'

        svg_begin = '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">' + '\n'
        path = f'\t<path fill="none" stroke="#ffffff" d="M{start_x},{start_y} L{end_x},{end_y}" />' + '\n'
        svg_end = '</svg>' + '\n'

        return ''.join([svg_begin, path, svg_end])

    def save_svg(self, svg, timestamp):
        """Write ``svg`` to ``<log_dir>/svgs/simple_line-<timestamp>.svg``."""
        # Explicit UTF-8 so the file matches the bytes handed to cairosvg.
        with open(self._svg_path(timestamp), 'w', encoding='utf-8') as f:
            f.write(svg)

    def save_img(self, svg, timestamp):
        """Render ``svg`` to ``<log_dir>/imgs/simple_line-<timestamp>.png``."""
        image_path = str(self._img_path(timestamp))
        cairosvg.svg2png(bytestring=svg.encode('UTF-8'), write_to=image_path)

In [3]:
# Build the generator; its constructor creates dataset/svgs and dataset/imgs.
dg = DatasetGenerator('dataset')

print('Starting!')

# Time the whole generation run with the wall clock.
start_time = time.time()
dg.generate_lines(count=100000)
end_time = time.time() - start_time  # elapsed seconds, despite the name

print('Done! - Took: ', end_time, 's')

Starting!
Done! - Took:  82.8606379032135 s
