# dataset

> End-to-end functions that take in centerline-stroke SVGs and output deltas in Stroke-3 format.

In [None]:
#| default_exp dataset

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os

import numpy as np
import pandas as pd

from singleline_stroke3.display import *
from singleline_stroke3.strokes import *
from singleline_stroke3.svg import *
from singleline_stroke3.transforms import *

## Bulk Processing SVGs into Stroke-3

In [None]:
#| export
def enumerate_files(input_dir):
    """List the names of the regular files located directly in `input_dir`.

    Sub-directories are neither returned nor descended into. The names are
    relative to `input_dir` and appear in the order `os.listdir` yields them.
    """
    return [
        entry
        for entry in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, entry))
    ]

In [None]:
#| export
def _debug_plot_path(output_dir, subdir, fname, suffix):
    """Return `<output_dir>/<subdir>/<stem of fname><suffix>`, creating the directory if needed."""
    target_dir = os.path.join(output_dir, subdir)
    os.makedirs(target_dir, exist_ok=True)
    # Swap only the real file extension; a plain `str.replace(".svg", ...)` would
    # also rewrite ".svg" in the middle of a name and leave non-".svg" names untouched.
    stem, _ext = os.path.splitext(fname)
    return os.path.join(target_dir, stem + suffix)


def svgs_to_deltas(
    input_dir,
    output_dir=None,
    target_size=200,
    total_n=1000,
    min_n=3,
    epsilon=1.0,
    limit=None,
):
    """Convert the SVG files found in `input_dir` into stroke-3 delta arrays.

    Each file is passed through `svg_to_strokes`, `merge_until`
    (dist_threshold=15.0), `splice_until` (dist_threshold=40.0) and finally
    `stroke_rdp_deltas`, which applies RDP path simplification. A file that
    raises at any stage is reported on stdout and skipped, so one bad drawing
    does not abort the whole run.

    Parameters
    ----------
    input_dir : str
        Directory scanned (non-recursively) for input files.
    output_dir : str, optional
        When given, debug PNGs of every pipeline stage are written to
        `<output_dir>/png/<NN>/`, where `NN` is the zero-padded number of
        strokes left after splicing. `<output_dir>/svg` is created as well.
    target_size : int
        Currently unused; kept so existing callers keep working.
    total_n, min_n : int
        Forwarded unchanged to `svg_to_strokes`.
    epsilon : float
        Forwarded unchanged to `stroke_rdp_deltas`.
    limit : int, optional
        Maximum number of files to process. `None` (or 0) means no limit.

    Returns
    -------
    numpy.ndarray
        1-D array of dtype `object` holding one delta array per successfully
        converted file.
    """
    if output_dir:
        for subdir in ("svg", "png"):
            os.makedirs(os.path.join(output_dir, subdir), exist_ok=True)

    all_files = enumerate_files(input_dir)
    print(f"found {len(all_files)} in {input_dir}")
    dataset = []
    for i, fname in enumerate(all_files):
        # `>=` so that exactly `limit` files are processed (indices 0..limit-1).
        if limit and i >= limit:
            break
        input_fname = os.path.join(input_dir, fname)

        try:
            rescaled_strokes = svg_to_strokes(input_fname, total_n=total_n, min_n=min_n)

            joined_strokes, _ = merge_until(rescaled_strokes, dist_threshold=15.0)
            spliced_strokes, _ = splice_until(joined_strokes, dist_threshold=40.0)

            print(
                f"{fname}: {len(rescaled_strokes)} strokes -> {len(joined_strokes)} joined -> {len(spliced_strokes)} spliced"
            )

            deltas = stroke_rdp_deltas(spliced_strokes, epsilon=epsilon)
            dataset.append(deltas)

            # monitor number of points before/after applying RDP path simplification algorithm
            raw_points = np.vstack(rescaled_strokes).shape[0]
            rdp_points = deltas.shape[0]
            print(f"{input_fname} points: raw={raw_points}, rdp={rdp_points}")

            if output_dir:
                # Group the debug images by the final stroke count of the drawing.
                subdir = os.path.join("png", f"{len(spliced_strokes):02d}")
                stages = (
                    (rescaled_strokes, ".0_strokes.png"),
                    (joined_strokes, ".1_joined.png"),
                    (spliced_strokes, ".2_spliced.png"),
                    (deltas_to_strokes(deltas), ".3_deltas.png"),
                )
                for stage_strokes, suffix in stages:
                    plot_strokes(
                        stage_strokes,
                        fname=_debug_plot_path(output_dir, subdir, fname, suffix),
                    )
        except Exception as e:
            # Best effort: report the failing file and continue with the next one.
            print(f"error processing idx={i} input_fname={input_fname}: {e}")

    # Fill a pre-allocated object array element by element. `np.array(dataset,
    # dtype=object)` would build an N-D array instead of a 1-D array of drawings
    # whenever every drawing happens to have the same shape (e.g. a single file).
    result = np.empty(len(dataset), dtype=object)
    for idx, drawing in enumerate(dataset):
        result[idx] = drawing
    return result

In [None]:
# input_dir = '../data/svg/'
# output_dir = '../outputs'

# # debug: only run for the first 10 files
# limit = 10

# _ = svgs_to_deltas(input_dir, output_dir, limit=limit)

In [None]:
# partial_dataset = svgs_to_deltas(input_dir, output_dir, limit=None)

In [None]:
# len(partial_dataset)
# np.savez('../outputs/subset.npz', partial_dataset, encoding='latin1', allow_pickle=True)

**Sidebar:** Visualizing all the images in the dataset (up to N strokes)

In [None]:
# from moviepy.editor import *

# imgs_01 = sorted(enumerate_files("../outputs_segmented/png/01/"))
# abs_01 = [os.path.join("../outputs_segmented/png/01", f) for f in imgs_01]
# imgs_02 = sorted(enumerate_files("../outputs_segmented/png/02/"))
# abs_02 = [os.path.join("../outputs_segmented/png/02", f) for f in imgs_02]
# imgs_03 = sorted(enumerate_files("../outputs_segmented/png/03/"))
# abs_03 = [os.path.join("../outputs_segmented/png/03", f) for f in imgs_03]
# imgs_04 = sorted(enumerate_files("../outputs_segmented/png/04/"))
# abs_04 = [os.path.join("../outputs_segmented/png/04", f) for f in imgs_04]

# all_fnames = abs_01 + abs_02 + abs_03 + abs_04

# new_clip = ImageSequenceClip(all_fnames, fps=20)
# new_clip.write_videofile("new_file_fps20.mp4")

# new_clip = ImageSequenceClip(all_fnames, fps=24)
# new_clip.write_videofile("new_file_fps24.mp4")

## Dataset Filtering

In [None]:
#| export
def stroke_summary_df(dataset):
    """Build a one-row-per-drawing summary table for `dataset`.

    Columns:
      * `idx`         - position of the drawing within `dataset`
      * `num_points`  - `len()` of the drawing's delta array
      * `num_strokes` - number of strokes returned by `deltas_to_strokes`
    """
    records = []
    for position, drawing in enumerate(dataset):
        point_count = len(drawing)
        stroke_count = len(deltas_to_strokes(drawing))
        records.append(
            {
                "idx": position,
                "num_points": point_count,
                "num_strokes": stroke_count,
            }
        )
    return pd.DataFrame(records)

In [None]:
#| export
def split_train_val(
    full_dataset,
    output_fname,
    split_ratio=0.8,
    max_strokes=None,
    max_points=None,
    min_points=None,
    random_state=None,
):
    """Filter, shuffle and split `full_dataset`, then save it as an `.npz` archive.

    Parameters
    ----------
    full_dataset : numpy.ndarray
        Array of drawings; it is indexed with a list of integer positions, so
        a plain Python list will not work.
    output_fname : str or file-like
        Destination passed to `np.savez`.
    split_ratio : float
        Fraction of the filtered drawings assigned to the training set. The
        resulting training size is rounded DOWN to a multiple of 100, so small
        datasets can end up with an empty training set.
    max_strokes, max_points, min_points : int, optional
        Keep only drawings with at most `max_strokes` strokes, at most
        `max_points` points and at least `min_points` points. All filters
        that are set are applied together; `None` (or 0) disables a filter.
    random_state : int or numpy random generator, optional
        Forwarded to `DataFrame.sample` to make the shuffle reproducible.

    The archive contains the arrays `train`, `valid` and `test`; `test` is the
    same data as `valid`. Object arrays are pickled, so read the file back with
    `np.load(..., allow_pickle=True)`.
    """
    summary_df = stroke_summary_df(full_dataset)

    # Accumulate every active filter into one mask so they combine, instead of
    # each filter restarting from the unfiltered table.
    keep = pd.Series(True, index=summary_df.index)
    if max_strokes:
        keep &= summary_df.num_strokes <= max_strokes
    if max_points:
        keep &= summary_df.num_points <= max_points
    if min_points:
        # Lower bound: drop drawings that have too few points.
        keep &= summary_df.num_points >= min_points
    filtered_df = summary_df[keep]

    shuffled_df = filtered_df.sample(frac=1, random_state=random_state)

    # Round the training size down to a whole multiple of 100.
    train_size = int(len(shuffled_df) * split_ratio / 100) * 100
    val_size = len(shuffled_df) - train_size
    print(train_size, val_size, len(shuffled_df))

    train_set = full_dataset[list(shuffled_df[:train_size].idx)]
    val_set = full_dataset[list(shuffled_df[train_size:].idx)]
    print(len(train_set), len(val_set))

    # `encoding` / `allow_pickle` are `np.load` options. Passing them to
    # `np.savez` as keywords stores them as extra arrays in the archive.
    np.savez(
        output_fname,
        train=train_set,
        valid=val_set,
        test=val_set,
    )

In [None]:
#| hide
import nbdev

# Run nbdev's notebook export step; presumably this writes the cells marked
# `#| export` into the module named by `#| default_exp` (confirm against nbdev docs).
nbdev.nbdev_export()