In [19]:
from PIL import Image
import torch.nn as  nn
import torchvision.transforms as standard_transforms
import numpy as np
from tqdm import tqdm

from datasets import Dataset, DatasetDict

In [13]:
id_to_chart_type = {
    -1: "other",
    1: "line",
    2: "scatter",
    4: "bar",
    7: "heat_map",
    9: "box-plot",
    10: "bubble",
    13: "sankey",
    14: "chord",
    15: "radial",
    16: "area",
    18: "donut",
    19: "choropleth",
    22: "treemap",
    29: "pie",
    31: "stream_graph",
    33: "hexabin",
    35: "graph",
    37: "parallel_coordinates",
    38: "sunburst",
    39: "waffle",
    40: "voronoi",
    41: "word_cloud",
    60: "contour",
    61: "filled-line",
    62: "scattergeo"
}

old_id_to_new = {
    -1: 1,
    1: 1,
    2: 2,
    4: 3,
    7: 4,
    9: 5,
    10: 6,
    13: 7,
    14: 8,
    15: 9,
    16: 10,
    18: 11,
    19: 12,
    22: 13,
    29: 14,
    31: 15,
    33: 16,
    35: 17,
    37: 18,
    38: 19,
    39: 20,
    40: 21,
    41: 22,
    60: 23,
    61: 24,
    62: 25
}


chart_type_to_id = {v: k for k, v in id_to_chart_type.items()}

In [11]:
mean_std = ( [.485, .456, .406], [.229, .224, .225])
fig_class_trasform = standard_transforms.Compose([
    standard_transforms.Resize((128, 128), interpolation=Image.ANTIALIAS),
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(*mean_std)
])

  standard_transforms.Resize((128, 128), interpolation=Image.ANTIALIAS),


In [21]:
data_dirs = ["./svg_datasets/plotly_export", "./svg_datasets/chartblocks", "./svg_datasets/d3_clean", "./svg_datasets/graphiq_clean", "./svg_datasets/fusion_clean"]
#["./svg_datasets/plotly_export", "./svg_datasets/chartblocks", "./svg_datasets/d3_clean", "./svg_datasets/graphiq_clean"]
X = []
y = []
for charts_dir in data_dirs:
    labels_path = f"{charts_dir}/urls.txt"
    id_to_label = {}

    with open(labels_path) as labels_file:
        for label_line in labels_file.readlines():
            label_split = label_line.split(" ")
            sample_id = label_split[0]
            sample_label = label_split[2].split(",")[0]
            if "plotly" in charts_dir:
                id_to_label[sample_id] = chart_type_to_id[sample_label.replace("\n", "")]
            else:
                id_to_label[sample_id] = int(sample_label.replace("\n", ""))

    err_count = 0
    processed_count = 0
    for chart_id, chart_label in tqdm(id_to_label.items()):
        img_path = f"{charts_dir}/images/{chart_id}.png"
        #print(svg_path)
        try:
            image = Image.open(img_path).convert('RGB')
            #svg_feature_arr = fig_class_trasform(image).numpy().flatten()
            X.append(image)
            y.append(old_id_to_new[chart_label])
            #processed_count += 1
            #if processed_count > 1000:
            #    break
        except Exception as err:
            err_count += 1

    print(f"Percentage of docs with parsing errors: {err_count / len(id_to_label)}")

100%|██████████| 15232/15232 [01:16<00:00, 197.92it/s]


Percentage of docs with parsing errors: 0.0008534663865546219


100%|██████████| 22557/22557 [01:41<00:00, 222.69it/s]


Percentage of docs with parsing errors: 0.0


100%|██████████| 1440/1440 [00:12<00:00, 113.14it/s]


Percentage of docs with parsing errors: 0.14305555555555555


100%|██████████| 2733/2733 [00:07<00:00, 342.74it/s]


Percentage of docs with parsing errors: 0.0


100%|██████████| 697/697 [00:02<00:00, 345.23it/s]

Percentage of docs with parsing errors: 0.0





In [22]:
dev_pct = 0.1
test_pct = 0.2
shuffled_idxs = np.random.choice(len(y), size=len(y), replace=False)

num_dev_elements = int(len(y) * dev_pct)
num_test_elements = int(len(y) * test_pct)
num_train_elements = int(len(y) * (1.0 - dev_pct - test_pct))

print(num_dev_elements, num_test_elements, num_train_elements)

train_idxs = shuffled_idxs[:num_train_elements]
dev_idxs = shuffled_idxs[num_train_elements:(num_train_elements + num_dev_elements)]
test_idxs = shuffled_idxs[(num_train_elements + num_dev_elements + 1):]

X = np.array(X)
X[X > 1e300] = 0.0
X[X < -1e300] = 0.0
X[np.isnan(X)] = 0.0
X_train = X[train_idxs]
X_dev = X[dev_idxs]
X_test = X[test_idxs]

y = np.array(y)
y_train = y[train_idxs]
y_dev = y[dev_idxs]
y_test = y[test_idxs]

train_ds = Dataset.from_dict({"image": list(X_train), "label": list(y_train)})
dev_ds = Dataset.from_dict({"image": list(X_dev), "label": list(y_dev)})
test_ds = Dataset.from_dict({"image": list(X_test), "label": list(y_test)})

full_dataset = DatasetDict({"train": train_ds, "dev": dev_ds, "test": test_ds})
full_dataset.save_to_disk("./beagle_chart_to_label.hf")

#np.save("./svg_datasets/imgs_X_train", X_train)
#np.save("./svg_datasets/imgs_X_dev", X_dev)
#np.save("./svg_datasets/imgs_X_test", X_test)
#np.save("./svg_datasets/imgs_y_train", y_train)
#np.save("./svg_datasets/imgs_y_dev", y_dev)
#np.save("./svg_datasets/imgs_y_test", y_test)

4244 8488 29707


  X = np.array(X)


: 