In [1]:

import cv2

from src.diagram.description_models import *
from src.diagram.ocr.model import *
from src.diagram.struct.model import *


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from rich import print as pp

In [4]:


from pathlib import Path

# Input files for this experiment.
# NOTE(review): absolute, machine-specific path — consider making it relative to a configurable data dir.
IMAGE_PATH = Path(
    "/home/petr/projects/mltests/dataset/out_image/0e5e7c91e225dd48e344d35c90e62fa7e867890268dd94df219d8b2eff0356aa.0.jpg")
DETECTOR_JSON_PATH = Path("a.json")  # serialized DetectorOutput
OCR_JSON_PATH = Path("b.json")  # serialized OCROutput

# Note: cv2.imread returns None (it does not raise) when the image cannot be read.
img = cv2.imread(str(IMAGE_PATH))
# read_text opens *and closes* the file (a bare open(...).read() leaks the handle);
# JSON is UTF-8, so decode explicitly instead of relying on the locale encoding.
detector_out = DetectorOutput.model_validate_json(DETECTOR_JSON_PATH.read_text(encoding="utf-8"))
ocr_out = OCROutput.model_validate_json(OCR_JSON_PATH.read_text(encoding="utf-8"))

In [5]:
from src.diagram.annotate.diagram import DiagramElementsGenerator

# Build the initial diagram contents (elements and links) from the detector output;
# the cells below label, nest and link them.
contents = DiagramElementsGenerator(detector_out)()

In [6]:
def iou_metrics(inner, outer):
    """Overlap statistics for two axis-aligned boxes ``(x1, y1, x2, y2)``.

    Coordinates are treated as inclusive pixel indices, hence the ``+ 1`` on
    every side length.

    Args:
        inner: the box tested for containment.
        outer: the reference (container) box.

    Returns:
        dict with
          intersect         - area of the overlap,
          a_area / b_area   - areas of ``inner`` / ``outer``,
          iou               - intersection over union,
          inters_over_inner - share of ``inner`` covered by the overlap
                              (1.0 means ``inner`` lies entirely inside ``outer``).
        A ratio is 0.0 when its denominator is zero (degenerate boxes).
    """

    def _area(box):
        # Clamp each side: an empty/inverted box has area 0 rather than a
        # spurious positive product of two negative side lengths.
        return max(0, box[2] - box[0] + 1) * max(0, box[3] - box[1] + 1)

    xA = max(inner[0], outer[0])
    yA = max(inner[1], outer[1])
    xB = min(inner[2], outer[2])
    yB = min(inner[3], outer[3])
    intersect = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    a_area = _area(inner)
    b_area = _area(outer)
    union = a_area + b_area - intersect
    return dict(
        intersect=intersect,
        a_area=a_area,
        b_area=b_area,
        # share of the union taken by the intersection
        iou=intersect / float(union) if union > 0 else 0.0,
        # share of the inner box covered by the intersection
        inters_over_inner=intersect / a_area if a_area > 0 else 0.0,
    )


def rank(data, key, desc=True):
    """Score every item of ``data`` with ``key`` and order by that score.

    Returns a list of ``(score, item)`` pairs: highest score first when
    ``desc`` is true, lowest first otherwise. The sort is stable, so items
    with equal scores keep their input order.
    """
    scored = [(key(item), item) for item in data]
    scored.sort(key=lambda pair: pair[0], reverse=desc)
    return scored

In [7]:


class Labeler:
    """Attach OCR text to diagram elements.

    Only task labels are resolved so far: OCR boxes lying (almost) entirely
    inside a task's bbox are joined in reading order. The other resolvers are
    stubs that always return None.

    The OCR pool is consumed while labels are assigned, so one instance is
    meant for a single ``run``.
    """

    # Weight of the vertical coordinate in the reading-order key. It must exceed
    # any x coordinate so rows dominate columns (top->bottom, then left->right).
    READING_ORDER_ROW_WEIGHT = 5000
    # Minimal share of an OCR box that must lie inside the element bbox.
    INTERNAL_LABEL_CUTOFF = 0.9

    def __init__(self, detector_data: DetectorOutput, ocr_data: OCROutput):
        self.detector_data = detector_data
        self.ocr_data = ocr_data
        # bbox -> recognized text. Bboxes are dict keys, so they must be hashable;
        # two OCR items with an identical bbox collapse into one entry.
        self.label_pool = {i.bbox: i.text for i in self.ocr_data.texts}

    def resolve_internal_labels(self, bbox):
        """Join the OCR texts lying inside ``bbox`` and remove them from the pool.

        Returns the joined text, or None when no OCR box matches.
        """
        if not self.label_pool:
            return None
        variants = rank(self.label_pool.items(),
                        lambda x: iou_metrics(x[0], bbox)['inters_over_inner'],
                        desc=True)
        variants = [i for r, i in variants if r > self.INTERNAL_LABEL_CUTOFF]
        # Assemble the text in reading order: rows top->bottom, each row left->right.
        # Assumes bbox = (x1, y1, x2, y2), the same layout iou_metrics uses.
        variants = rank(variants,
                        key=lambda i: i[0][1] * self.READING_ORDER_ROW_WEIGHT + i[0][0],
                        desc=False)
        parts = []
        for _, (bb, txt) in variants:
            parts.append(txt)
            self.label_pool.pop(bb)  # consumed: not offered to other elements
        return " ".join(parts).strip() or None

    def resolve_external_label(self, bbox):
        """Label placed outside a gateway/event bbox. Not implemented: always None."""
        return None

    def resolve_external_label_for_line(self, line):
        """Label placed along a link's polyline. Not implemented: always None."""
        return None

    def resolve_label_for_process(self, obj):
        """Label of a process (``run`` passes the element bbox). Not implemented: always None."""
        return None

    def resolve_label_for_pool(self, obj):
        """Label of a lane (``run`` passes the element bbox). Not implemented: always None."""
        return None

    def run(self, diag: DiagramContents) -> DiagramContents:
        """Return a labelled copy of ``diag``; ``diag`` itself is left untouched."""
        # deep=True: a shallow model_copy shares the element/link objects, so the
        # label assignments below would silently mutate ``diag`` as well.
        out = diag.model_copy(deep=True)
        external_label_types = {GBPMNElementType.GATEWAY, GBPMNElementType.EVENT_START,
                                GBPMNElementType.EVENT_END, GBPMNElementType.EVENT_CATCH,
                                GBPMNElementType.EVENT_THROW}
        for i in out.elements:
            if i.type == GBPMNElementType.TASK:
                i.label = self.resolve_internal_labels(i.bbox)
            elif i.type == GBPMNElementType.VIRT_LANE:
                i.label = self.resolve_label_for_pool(i.bbox)
            elif i.type == GBPMNElementType.VIRT_PROC:
                i.label = self.resolve_label_for_process(i.bbox)
            elif i.type in external_label_types:
                i.label = self.resolve_external_label(i.bbox)
        for i in out.links:
            i.label = self.resolve_external_label_for_line(i.line)
        return out


# Attach OCR text to the elements; `contents` is rebound to the result of Labeler.run().
# NOTE(review): because `contents` is rebound here and in later cells, re-running this
# cell alone operates on whatever `contents` currently holds — re-run from the generator cell.
x = Labeler(detector_out, ocr_out)
contents = x.run(contents)

In [8]:
from uuid import uuid4
from src.diagram.description_models import DiagramContents

# Container element types (lanes and processes). Every other element type is a
# "basic" element that DiagramNestBinder may nest inside a lane.
AGGREGATED_ETYPES = {GBPMNElementType.VIRT_LANE, GBPMNElementType.VIRT_PROC}


class DiagramNestBinder:
    """Bind processes, lanes and basic elements into a containment hierarchy.

    For every process, find the lanes lying inside it. If it has none, create a
    single fallback lane covering the whole process. Then find the basic
    elements inside every lane. Lanes are replaced by ``GBPMNLaneElement``
    (carrying ``process_id`` and ``nested_ids``), and processes by
    ``GBPMNProcessElement`` (carrying ``lanes_ids``).
    """

    # Minimal share of a lane's area that must lie inside a process.
    LANE_IN_PROCESS_CUTOFF = 0.7
    # Minimal share of an element's area that must lie inside a lane.
    ELEMENT_IN_LANE_CUTOFF = 0.9

    def __init__(self, data: DiagramContents):
        # deep=True: the drop()/add() calls below must not reach the caller's
        # contents; a shallow copy shares the element collection with ``data``.
        self.data = data.model_copy(deep=True)
        self.basic_elements = [i
                               for i in self.data.elements
                               if i.type not in AGGREGATED_ETYPES]

    def __scan_internals(self, bbox, source=None, cutoff=0.9):
        """Items of ``source`` with more than ``cutoff`` of their area inside ``bbox``.

        Results are ordered by that share, highest first. ``source`` defaults to
        the basic elements only when omitted (None); an explicitly empty list
        yields no matches.
        """
        # Test for None, not truthiness: an empty lane list must not fall back to
        # the basic elements, or tasks inside a lane-less process become "lanes".
        if source is None:
            source = self.basic_elements
        variants = rank(source,
                        lambda x: iou_metrics(x.bbox, bbox)['inters_over_inner'],
                        desc=True)
        return [i for r, i in variants if r > cutoff]

    def __processes(self):
        """All process (VIRT_PROC) elements currently in ``self.data``."""
        return [i for i in self.data.elements if i.type == GBPMNElementType.VIRT_PROC]

    def __lanes(self):
        """All lane (VIRT_LANE) elements currently in ``self.data``."""
        return [i for i in self.data.elements if i.type == GBPMNElementType.VIRT_LANE]

    def __call__(self) -> DiagramContents:
        """Bind everything and return the updated copy of the contents."""
        for proc in self.__processes():
            # 1. lanes inside this process
            lanes = self.__scan_internals(proc.bbox, source=self.__lanes(),
                                          cutoff=self.LANE_IN_PROCESS_CUTOFF)
            if not lanes:
                # 2. no explicit lanes: one default lane spanning the whole process
                lanes = [GBPMNElement(
                    id=str(uuid4()),
                    label=proc.label,
                    type=GBPMNElementType.VIRT_LANE,
                    bbox=proc.bbox,
                )]

            # 3. contents of every lane, 4. link process <-> lanes <-> contents
            lane_ids = []
            for lane in lanes:
                lane_content = self.__scan_internals(lane.bbox, source=self.basic_elements,
                                                     cutoff=self.ELEMENT_IN_LANE_CUTOFF)
                # Merge into the dump rather than passing extra keywords: a lane that
                # is already bound presumably dumps process_id/nested_ids too, and
                # duplicate keywords would raise TypeError instead of overriding.
                bound_lane = GBPMNLaneElement(**{
                    **lane.model_dump(),
                    "process_id": proc.id,
                    "nested_ids": [i.id for i in lane_content],
                })
                lane_ids.append(bound_lane.id)
                self.data.drop(bound_lane.id)
                self.data.add(bound_lane)

            bound_proc = GBPMNProcessElement(**{**proc.model_dump(), "lanes_ids": lane_ids})
            self.data.drop(bound_proc.id)
            self.data.add(bound_proc)

        return self.data


# Bind lanes to their processes and basic elements to their lanes, then show the elements.
# NOTE(review): `contents` is rebound, so re-running this cell alone feeds it already
# nested contents — re-run the chain from the generator cell instead.
contents = DiagramNestBinder(contents)()
pp(contents.elements)

In [9]:
from src.diagram.annotate.tools import dist_pt2bbox

# A link endpoint is attached to an element only if dist_pt2bbox(endpoint, bbox) is
# below this value (units of dist_pt2bbox — presumably pixels; confirm).
MAX_LINK2OBJ_DISTANCE = 10


class DiagramLinkBinder:
    """Attach each link's endpoints to the nearest linkable elements.

    Endpoints that already carry an id are kept as they are; only missing ones
    are resolved from the link's polyline (first point -> source, last point ->
    target). A link whose endpoint cannot be resolved within
    ``MAX_LINK2OBJ_DISTANCE`` is dropped, and a summary line is printed.
    """

    def __init__(self, data: DiagramContents):
        # deep=True: endpoint ids are assigned on the link objects below; a shallow
        # copy would write them into the caller's links as well.
        self.data = data.model_copy(deep=True)
        # Lanes are excluded as link endpoints; processes remain linkable.
        self.linkable = [i
                         for i in self.data.elements
                         if i.type not in {GBPMNElementType.VIRT_LANE}]

    def __search_near(self, pos):
        """Closest linkable element to point ``pos`` within MAX_LINK2OBJ_DISTANCE, or None."""
        variants = rank(self.linkable, lambda x: dist_pt2bbox(pos, x.bbox), desc=False)
        variants = [i for r, i in variants if r < MAX_LINK2OBJ_DISTANCE]
        return variants[0] if variants else None

    def __nearest_id(self, pos):
        """Id of the closest linkable element to ``pos``, or None if none is in range."""
        found = self.__search_near(pos)
        return found.id if found is not None else None

    def __call__(self) -> DiagramContents:
        """Bind link endpoints and return the updated copy of the contents."""
        old_count = len(self.data.links)
        links = []
        for link in self.data.links:
            # Resolve only the missing endpoints. Fully bound links are kept as-is:
            # skipping them before the append would silently drop them.
            source_id = link.source_id or self.__nearest_id(link.line[0])
            target_id = link.target_id or self.__nearest_id(link.line[-1])
            if not source_id or not target_id:
                continue  # an endpoint is not attached to anything -> drop the link
            link.source_id = source_id
            link.target_id = target_id
            links.append(link)
        self.data.links = links
        new_count = len(links)
        if new_count != old_count:
            print(f"links dropped {old_count} -> {new_count}")
        return self.data


# Bind link endpoints to nearby elements; the binder drops links it cannot bind and
# prints how many remain. Then show the links.
contents = DiagramLinkBinder(contents)()
pp(contents.links)

links dropped 32 -> 27


In [10]:
from src.diagram.description_models import DiagramContents, GBPMNDiagram


class DiagramBuilder:
    """Assemble a ``GBPMNDiagram`` from nested and linked ``DiagramContents``.

    Processes contain the lanes whose ``process_id`` points at them; lanes
    contain the basic elements listed in their ``nested_ids``. Flows are keyed
    by ``(source_id, target_id)`` and objects by id.
    """

    def __init__(self, data: DiagramContents):
        self.data = data

    def __elements(self, etype):
        """Elements of the given ``GBPMNElementType``, in the order of ``data.elements``."""
        return [i for i in self.data.elements if i.type == etype]

    def __call__(self) -> GBPMNDiagram:
        """Build and return the diagram."""
        container_types = {GBPMNElementType.VIRT_LANE, GBPMNElementType.VIRT_PROC}
        basic_elements = [o for o in self.data.elements if o.type not in container_types]
        all_lanes = self.__elements(GBPMNElementType.VIRT_LANE)

        processes = []
        for p in self.__elements(GBPMNElementType.VIRT_PROC):
            lanes = []
            for l in all_lanes:
                # A lane never bound to a process may be a plain element without a
                # ``process_id`` attribute (NOTE(review): confirm against the model);
                # treat it as belonging to no process instead of raising AttributeError.
                if getattr(l, "process_id", None) != p.id:
                    continue
                nested = set(l.nested_ids)  # O(1) membership instead of a list scan per element
                lanes.append(GBPMNLane(
                    id=l.id,
                    label=l.label,
                    bbox=l.bbox,
                    objects=[o for o in basic_elements if o.id in nested],
                ))
            processes.append(GBPMNProcess(
                id=p.id, label=p.label, bbox=p.bbox, lanes=lanes
            ))

        return GBPMNDiagram(
            processes=processes,
            # Keyed by (source, target): parallel links between the same two
            # elements collapse into a single entry (the last one wins).
            flows={(i.source_id, i.target_id): i for i in self.data.links},
            objects={i.id: i for i in self.data.elements},
        )


# Assemble the diagram model (processes -> lanes -> objects, plus flows and an id->object map) and show it.
bpmnd = DiagramBuilder(contents)()
pp(bpmnd)

In [11]:
from io import BytesIO
from matplotlib import pyplot as plt
import networkx as nx
import iplotx as ipx
from src.diagram.annotate.tools import bbox_center
from src.diagram.description_models import DiagramContents, GBPMNDiagram

# Node fill colour per element type, used by GraphBuilder.visualize.
# Only basic element types are listed: lane objects exclude lanes and processes.
G_NODE_COLOR = {
    GBPMNElementType.TASK: "yellow",
    GBPMNElementType.GATEWAY: "red",
    GBPMNElementType.EVENT_START: "green",
    GBPMNElementType.EVENT_END: "gray",
    GBPMNElementType.EVENT_CATCH: "blue",
    GBPMNElementType.EVENT_THROW: "cyan",
}


class GraphBuilder:
    """Directed sequence-flow graph of the diagram's basic elements.

    Nodes are the objects of every lane, keyed by element id, with the
    element's ``model_dump()`` as attributes. Edges are the SEQUENCE flows
    between those nodes.
    """

    # Fill colour for element types missing from G_NODE_COLOR.
    DEFAULT_NODE_COLOR = "white"

    def __init__(self, contents: DiagramContents, diagram: GBPMNDiagram):
        self.contents = contents
        self.diagram = diagram
        self.graph = nx.DiGraph()
        self.push_nodes()
        self.push_edges()

    def push_nodes(self):
        """Add one node per lane object."""
        for p in self.diagram.processes:
            for l in p.lanes:
                for o in l.objects:
                    self.graph.add_node(o.id, **o.model_dump())

    def push_edges(self):
        """Add an edge for every SEQUENCE flow whose endpoints are both nodes.

        networkx would otherwise create attribute-less nodes for unknown
        endpoints (e.g. elements outside every lane). Those nodes break
        ``create_layout``/``visualize``, which read ``bbox``/``type`` from node data.
        """
        for (s1, s2), e in self.diagram.flows.items():
            if e.type != GBPMNFlowType.SEQUENCE:
                continue
            if s1 not in self.graph or s2 not in self.graph:
                continue
            self.graph.add_edge(s1, s2, **e.model_dump())

    def create_layout(self):
        """Node positions taken from the centres of the element bboxes.

        Every node is passed as ``fixed``, so spring_layout keeps each one at its
        initial position (no rescaling); the call only converts them to arrays.
        NOTE(review): these are image coordinates (y grows downwards) — the plot
        may appear vertically flipped.
        """
        pos = {
            i: bbox_center(v['bbox'])
            for i, v in self.graph.nodes.data()
        }
        return nx.spring_layout(self.graph, pos=pos, fixed=self.graph.nodes)

    def visualize(self, sz=(640, 480), dpi=80):
        """Render the graph and return it as PNG bytes.

        Args:
            sz: canvas size in pixels (width, height) before the tight bbox crop.
            dpi: rendering resolution; figure size in inches is ``sz / dpi``.
        """
        vertex_color = [G_NODE_COLOR.get(v['type'], self.DEFAULT_NODE_COLOR)
                        for _, v in self.graph.nodes.data()]
        layout = self.create_layout()
        fig, ax = plt.subplots(figsize=(sz[0] / dpi, sz[1] / dpi), dpi=dpi)
        try:
            ipx.plot(self.graph, ax=ax, layout=layout, vertex_facecolor=vertex_color)
            buf = BytesIO()
            fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
            return buf.getvalue()
        finally:
            # Release the figure even when plotting fails.
            plt.close(fig)

    def __call__(self) -> nx.DiGraph:
        """The built graph."""
        return self.graph


gb = GraphBuilder(contents, bpmnd)
graph = gb()

# Save a PNG preview of the sequence-flow graph to the current working directory.
GRAPH_PREVIEW_PATH = "test.png"
with open(GRAPH_PREVIEW_PATH, "wb") as f:
    f.write(gb.visualize())


In [40]:
import numpy as np
import importlib


In [20]:
from src.renderer.renderbpmn import BPMNRenderer, save_all

# Renderer used in the codegen cell below via `await renderer.render_by_code(code)`.
renderer = BPMNRenderer()

In [46]:
# Extent of the node layout (image coordinates): min/max corner and size.
node_layout = gb.create_layout()
assert node_layout, "graph has no nodes - nothing to measure"
all_pts = np.array(list(node_layout.values()))
p_min, p_max = all_pts.min(axis=0), all_pts.max(axis=0)
b_range = p_max - p_min
p_min, p_max


(array([ 74.5, 254. ]), array([1043.5,  745. ]))

In [48]:
import src.renderer.codegen as codegen
# Explicit reload of the codegen module (on top of %autoreload) so edits to it are picked up.
importlib.reload(codegen)


# Generate BPMN code from the graph and its layout, render it, and pass the result to save_all.
# NOTE(review): meaning of scale=2 is defined in codegen — presumably it multiplies layout coordinates.
cg = codegen.GraphBPMNCodegen()
code = cg(graph, gb.create_layout(), scale=2)
# Top-level `await` relies on IPython's autoawait support; it is not valid in a plain .py script.
save_all(await renderer.render_by_code(code))
