In [None]:
import plot_creation
from ultralytics import YOLO
from PIL import Image
import datetime
from tqdm import tqdm

In [None]:
# Plot type name -> integer class id. create_plot_type_dataset looks up
# dct["plot"]["type"] in this dict when it writes a label line.
mapping = {
    "scatter": 0,
    "vertical_bar": 1,
    "horizontal_bar": 2,
}

def create_plot_type_dataset(path="datasets/plot_type_dataset", start_idx=0, size=10, width_px=1920, height_px=1920):
    """Render images with ``plot_creation`` and write one label file per image.

    For every index ``i`` in ``range(start_idx, size)`` an image is rendered to
    ``<path>/images/<i zero-padded to 5 digits>.jpg`` and a text file with the
    same stem is written to ``<path>/labels/``. Each plot returned by
    ``plot_creation.create_img`` contributes one line of the form
    ``<class id> <x_center> <y_center> <width> <height>``, where the box
    values are the plot's ``coords_px`` divided by ``width_px`` / ``height_px``.

    Args:
        path: Dataset root. The ``labels`` sub-directory must already exist;
            ``open(..., "w")`` does not create missing directories.
        start_idx: First image index to generate.
        size: Exclusive upper bound of the index range (not a count): nothing
            is generated when ``start_idx >= size``.
        width_px: Image width used to normalise x values. Defaults to the
            previously hard-coded 1920 -- assumes create_img renders at that
            size; TODO confirm.
        height_px: Image height used to normalise y values (default 1920).

    Raises:
        KeyError: If a plot type returned by ``create_img`` is not a key of
            the module-level ``mapping``.

    Note:
        ``mapping`` is read from module scope at call time. Later cells of
        this notebook rebind that name, so run this builder before them or
        re-run the cell that defines the plot-type mapping first.
    """
    img_path_template = path + "/images/{}.jpg"
    label_path_template = path + "/labels/{}.txt"
    label_data_template = "{plot_type} {x_center} {y_center} {width} {height}\n"
    for i in tqdm(range(start_idx, size)):
        stem = str(i).zfill(5)
        dcts = plot_creation.create_img(img_path=img_path_template.format(stem))
        # Context manager guarantees the label file is closed even if a lookup
        # below raises part-way through the image.
        with open(label_path_template.format(stem), "w") as file:
            for dct in dcts:
                plot_type = mapping[dct["plot"]["type"]]
                coords = dct["plot"]["coords_px"]
                first, second = coords[0], coords[1]
                x_center = (second[0] + first[0]) / (2 * width_px)
                y_center = (second[1] + first[1]) / (2 * height_px)
                width = (second[0] - first[0]) / width_px
                height = (second[1] - first[1]) / height_px
                file.write(label_data_template.format(plot_type=plot_type, x_center=x_center, y_center=y_center, width=width, height=height))

In [None]:
# Build the train split (indices 0..14999), then the val split (indices 0..199).
create_plot_type_dataset(
    path="datasets/plot_type_dataset/train",
    start_idx=0,
    size=15000,
)
create_plot_type_dataset(
    path="datasets/plot_type_dataset/val",
    size=200,
)

In [None]:
# Element name -> integer class id (0-6).
# Not read by the code in this cell: create_general_info_dataset derives the
# same ids positionally. Rebinding `mapping` here also replaces the plot-type
# dict that create_plot_type_dataset reads at call time.
mapping = {
    "title": 0,
    "xtitle": 1,
    "ytitle": 2,
    "xtick": 3,
    "ytick": 4,
    "xlabel": 5,
    "ylabel": 6,
}

# Detector loaded once and shared by the dataset builders in the following
# cells, which call model.predict("result.jpg") and crop to the returned boxes.
# The weights file name suggests it is the plot-type detector -- TODO confirm.
model = YOLO("../weights/yolo_plot_type.pt")

def coords_format(point, new_xstart, new_ystart, new_width, new_height):
    """Express a corner-pair box as centre/size fractions of a crop window.

    Args:
        point: ``((x0, y0), (x1, y1))`` corner pair.
        new_xstart: x offset subtracted from both corners.
        new_ystart: y offset subtracted from both corners.
        new_width: Crop width the x values are divided by.
        new_height: Crop height the y values are divided by.

    Returns:
        Tuple ``(x_center, y_center, width, height)``.
    """
    # Shift both corners into the crop's coordinate system.
    x0 = point[0][0] - new_xstart
    y0 = point[0][1] - new_ystart
    x1 = point[1][0] - new_xstart
    y1 = point[1][1] - new_ystart
    center = ((x1 + x0) / (2 * new_width), (y1 + y0) / (2 * new_height))
    extent = ((x1 - x0) / new_width, (y1 - y0) / new_height)
    return center + extent

def create_general_info_dataset(path="datasets/general_info_dataset", start_idx=0, size=10):
    """Build cropped single-plot images with labels for titles, ticks and labels.

    Each iteration renders one plot via ``plot_creation.create_img``, runs the
    module-level ``model`` on ``result.jpg`` and, when the detection is
    accepted, crops the image to the detected box. The crop is saved to
    ``<path>/images/<idx zero-padded to 5 digits>.jpg`` and the label lines
    ``<class id> <x_center> <y_center> <width> <height>`` (relative to the
    crop, via ``coords_format``) go to ``<path>/labels/<same stem>.txt``.

    Class ids are assigned positionally: ``title``/``xlabel``/``ylabel``
    (one box each) get 0/1/2, and ``xticks``/``yticks``/``xlabels``/``ylabels``
    (a list of boxes each) get 3/4/5/6. Keys missing from the plot's dict are
    skipped.

    Args:
        path: Dataset root; its ``images`` and ``labels`` sub-directories are
            not created here.
        start_idx: Index of the first sample to write.
        size: Exclusive upper bound for the sample index (not a count).

    Note:
        A rendered plot is rejected, and a new one generated at the same
        index, when the number of detected boxes differs from the number of
        plots or when the first box's confidence is below 0.95.
    """
    img_path_template = path + "/images/{}.jpg"
    label_path_template = path + "/labels/{}.txt"
    label_data_template = "{info_type} {x_center} {y_center} {width} {height}\n"
    single_box_keys = ["title", "xlabel", "ylabel"]
    multi_box_keys = ["xticks", "yticks", "xlabels", "ylabels"]
    idx = start_idx
    while idx < size:
        if idx % 100 == 0:
            print(f"{idx}: {datetime.datetime.now()}")
        # Assumes create_img writes to "result.jpg" when no img_path is
        # given -- TODO confirm against plot_creation.
        dcts = plot_creation.create_img(plot_amount=1)
        preds = model.predict("result.jpg", verbose=False)
        boxes = preds[0].boxes
        # Only the first box's confidence is checked; a single plot is requested.
        if len(dcts) != len(boxes) or boxes.conf[0] < 0.95:
            continue
        for dct, box in zip(dcts, boxes):
            stem = str(idx).zfill(5)
            new_borders = tuple(map(int, box.xyxy.tolist()[0]))
            # Context managers release the image and label handles even when
            # a dict lookup or write below raises.
            with Image.open("result.jpg") as image:
                cropped = image.crop(new_borders)
                cropped.save(img_path_template.format(stem))
                new_width, new_height = cropped.size
            with open(label_path_template.format(stem), "w") as file:
                for class_id, info in enumerate(single_box_keys):
                    if info not in dct:
                        continue
                    point = dct[info]["coords_px"]
                    x_center, y_center, width, height = coords_format(point, new_borders[0], new_borders[1], new_width, new_height)
                    file.write(label_data_template.format(info_type=class_id, x_center=x_center, y_center=y_center, width=width, height=height))
                # Multi-box classes continue numbering after the single-box ones.
                for class_id, info in enumerate(multi_box_keys, start=len(single_box_keys)):
                    if info not in dct:
                        continue
                    for point in dct[info]["coords_px"]:
                        x_center, y_center, width, height = coords_format(point, new_borders[0], new_borders[1], new_width, new_height)
                        file.write(label_data_template.format(info_type=class_id, x_center=x_center, y_center=y_center, width=width, height=height))
            idx += 1

In [None]:
# Build the train split (sample indices 0..2999), then the val split (0..99).
create_general_info_dataset(
    path="datasets/general_info_dataset/train",
    start_idx=0,
    size=3000,
)
create_general_info_dataset(
    path="datasets/general_info_dataset/val",
    size=100,
)

In [None]:
# Single-class label table. Not read by the code in this cell:
# create_points_dataset writes class id 0 directly. Rebinding `mapping`
# replaces whatever dict an earlier cell stored under this name.
mapping = {
    "point": 0,
}

# Behaviourally the same as the coords_format defined in an earlier cell;
# running this cell simply rebinds the name.
def coords_format(point, new_xstart, new_ystart, new_width, new_height):
    """Express a corner-pair box as centre/size fractions of a crop window.

    Args:
        point: ``((x0, y0), (x1, y1))`` corner pair.
        new_xstart: x offset subtracted from both corners.
        new_ystart: y offset subtracted from both corners.
        new_width: Crop width the x values are divided by.
        new_height: Crop height the y values are divided by.

    Returns:
        Tuple ``(x_center, y_center, width, height)``.
    """
    # Shift both corners into the crop's coordinate system.
    x0 = point[0][0] - new_xstart
    y0 = point[0][1] - new_ystart
    x1 = point[1][0] - new_xstart
    y1 = point[1][1] - new_ystart
    center = ((x1 + x0) / (2 * new_width), (y1 + y0) / (2 * new_height))
    extent = ((x1 - x0) / new_width, (y1 - y0) / new_height)
    return center + extent

def create_points_dataset(path="datasets/points_dataset", start_idx=0, size=10):
    """Build cropped single-plot images with one label line per plotted point.

    Each iteration renders one plot via ``plot_creation.create_img``, runs the
    module-level ``model`` on ``result.jpg`` and, when the detection is
    accepted, crops the image to the detected box. The crop is saved to
    ``<path>/images/<idx zero-padded to 5 digits>.jpg`` and every entry of
    ``dct["points"]["coords_px"]`` becomes a line
    ``0 <x_center> <y_center> <width> <height>`` (relative to the crop, via
    ``coords_format``) in ``<path>/labels/<same stem>.txt``.

    Args:
        path: Dataset root; its ``images`` and ``labels`` sub-directories are
            not created here.
        start_idx: Index of the first sample to write.
        size: Exclusive upper bound for the sample index (not a count).

    Note:
        A rendered plot is discarded, and a new one generated at the same
        index, when the box count differs from the plot count, when the first
        box's confidence is below 0.95, or when the plot has no ``"points"``
        entry.
    """
    img_path_template = path + "/images/{}.jpg"
    label_path_template = path + "/labels/{}.txt"
    label_data_template = "{info_type} {x_center} {y_center} {width} {height}\n"
    info = "points"
    idx = start_idx
    while idx < size:
        if idx % 100 == 0:
            print(f"{idx}: {datetime.datetime.now()}")
        # Assumes create_img writes to "result.jpg" when no img_path is
        # given -- TODO confirm against plot_creation.
        dcts = plot_creation.create_img(plot_amount=1)
        preds = model.predict("result.jpg", verbose=False)
        boxes = preds[0].boxes
        # Only the first box's confidence is checked; a single plot is requested.
        if len(dcts) != len(boxes) or boxes.conf[0] < 0.95:
            continue
        for dct, box in zip(dcts, boxes):
            # Skip before touching the disk: previously the label file was
            # opened first and left unclosed on this path.
            if info not in dct:
                continue
            stem = str(idx).zfill(5)
            new_borders = tuple(map(int, box.xyxy.tolist()[0]))
            with Image.open("result.jpg") as image:
                cropped = image.crop(new_borders)
                cropped.save(img_path_template.format(stem))
                new_width, new_height = cropped.size
            with open(label_path_template.format(stem), "w") as file:
                for point in dct[info]["coords_px"]:
                    x_center, y_center, width, height = coords_format(point, new_borders[0], new_borders[1], new_width, new_height)
                    file.write(label_data_template.format(info_type=0, x_center=x_center, y_center=y_center, width=width, height=height))
            idx += 1

In [None]:
# Build the train split (sample indices 0..9999), then the val split (0..99).
create_points_dataset(
    path="datasets/points_dataset/train",
    start_idx=0,
    size=10000,
)
create_points_dataset(
    path="datasets/points_dataset/val",
    size=100,
)

In [None]:
# Single-class label table. Not read by the code in this cell:
# create_bars_dataset writes class id 0 directly. Rebinding `mapping`
# replaces whatever dict an earlier cell stored under this name.
mapping = {
    "bar": 0,
}

# Behaviourally the same as the coords_format defined in earlier cells;
# running this cell simply rebinds the name.
def coords_format(point, new_xstart, new_ystart, new_width, new_height):
    """Express a corner-pair box as centre/size fractions of a crop window.

    Args:
        point: ``((x0, y0), (x1, y1))`` corner pair.
        new_xstart: x offset subtracted from both corners.
        new_ystart: y offset subtracted from both corners.
        new_width: Crop width the x values are divided by.
        new_height: Crop height the y values are divided by.

    Returns:
        Tuple ``(x_center, y_center, width, height)``.
    """
    # Shift both corners into the crop's coordinate system.
    x0 = point[0][0] - new_xstart
    y0 = point[0][1] - new_ystart
    x1 = point[1][0] - new_xstart
    y1 = point[1][1] - new_ystart
    center = ((x1 + x0) / (2 * new_width), (y1 + y0) / (2 * new_height))
    extent = ((x1 - x0) / new_width, (y1 - y0) / new_height)
    return center + extent

def create_bars_dataset(path="../datasets/bars_dataset", start_idx=0, size=10):
    """Build cropped single-plot images with one label line per bar.

    Each iteration renders one plot via ``plot_creation_bars_only.create_img``,
    runs the module-level ``model`` on ``result.jpg`` and, when the detection
    is accepted, crops the image to the detected box. The crop is saved to
    ``<path>/images/<idx zero-padded to 5 digits>.jpg`` and every entry of
    ``dct["bars"]["coords_px"]`` becomes a line
    ``0 <x_center> <y_center> <width> <height>`` (relative to the crop, via
    ``coords_format``) in ``<path>/labels/<same stem>.txt``.

    Args:
        path: Dataset root; its ``images`` and ``labels`` sub-directories are
            not created here.
        start_idx: Index of the first sample to write.
        size: Exclusive upper bound for the sample index (not a count).

    Note:
        A rendered plot is discarded, and a new one generated at the same
        index, when the box count differs from the plot count, when the first
        box's confidence is below 0.95, or when the plot has no ``"bars"``
        entry.
    """
    img_path_template = path + "/images/{}.jpg"
    label_path_template = path + "/labels/{}.txt"
    label_data_template = "{info_type} {x_center} {y_center} {width} {height}\n"
    info = "bars"
    idx = start_idx
    while idx < size:
        if idx % 100 == 0:
            print(f"{idx}: {datetime.datetime.now()}")
        # NOTE(review): plot_creation_bars_only is not imported in the visible
        # import cell -- confirm it is imported before this function is called.
        # Also assumes create_img writes to "result.jpg" by default -- TODO confirm.
        dcts = plot_creation_bars_only.create_img(plot_amount=1)
        preds = model.predict("result.jpg", verbose=False)
        boxes = preds[0].boxes
        # Only the first box's confidence is checked; a single plot is requested.
        if len(dcts) != len(boxes) or boxes.conf[0] < 0.95:
            continue
        for dct, box in zip(dcts, boxes):
            # Skip before touching the disk: previously the label file was
            # opened first and left unclosed on this path.
            if info not in dct:
                continue
            stem = str(idx).zfill(5)
            new_borders = tuple(map(int, box.xyxy.tolist()[0]))
            with Image.open("result.jpg") as image:
                cropped = image.crop(new_borders)
                cropped.save(img_path_template.format(stem))
                new_width, new_height = cropped.size
            with open(label_path_template.format(stem), "w") as file:
                for point in dct[info]["coords_px"]:
                    x_center, y_center, width, height = coords_format(point, new_borders[0], new_borders[1], new_width, new_height)
                    file.write(label_data_template.format(info_type=0, x_center=x_center, y_center=y_center, width=width, height=height))
            idx += 1

In [None]:
# Build the train split (sample indices 0..2999), then the val split (0..99).
# The original cell called `create_dataset`, which is not defined in this
# notebook; the bars builder defined above is create_bars_dataset.
create_bars_dataset(path="../datasets/bars_dataset/train", size=3000, start_idx=0)
create_bars_dataset(path="../datasets/bars_dataset/val", size=100)