In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell runs and source edits take effect without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from atria_datasets.registry.document_classification.rvlcdip import *  # noqa
from atria_datasets.registry.document_classification.tobacco3482 import *  # noqa
from atria_datasets.registry.image_classification.cifar10 import *  # noqa
from atria_datasets.registry.image_classification.cifar10_huggingface import *  # noqa
from atria_datasets.registry.image_classification.mnist import *  # noqa
# from atria_datasets.registry.layout_analysis.doclaynet import *  # noqa
# from atria_datasets.registry.layout_analysis.icdar2019 import *  # noqa
# from atria_datasets.registry.layout_analysis.publaynet import *  # noqa
from atria_datasets.registry.ser.cord import *  # noqa
# from atria_datasets.registry.ser.docbank import *  # noqa
# from atria_datasets.registry.ser.docile import *  # noqa
from atria_datasets.registry.ser.funsd import *  # noqa
from atria_datasets.registry.ser.sroie import *  # noqa
from atria_datasets.registry.ser.wild_receipts import *  # noqa
# from atria_datasets.registry.table_extraction.fintabnet import *  # noqa
# from atria_datasets.registry.table_extraction.icdar2013 import *  # noqa
# from atria_datasets.registry.table_extraction.pubtables1m import *  # noqa
# from atria_datasets.registry.vqa.docvqa import *  # noqa
# from atria_datasets.registry.vqa.due import *  # noqa



In [3]:
# First, list every dataset module that DATASET reports as available.
from atria_datasets import DATASET

available_modules = DATASET.list_all_modules()
print("Available dataset modules:")
for module in available_modules:
    print(f"- {module}")

Available dataset modules:
- rvlcdip/image
- rvlcdip/image_with_ocr
- rvlcdip/image_with_ocr_1k
- tobacco3482/image_only
- tobacco3482/image_with_ocr
- cifar10/1k
- huggingface_cifar10/plain_text
- huggingface_cifar10/plain_text_1k
- mnist/default
- mnist/1k
- cord
- funsd
- sroie
- wild_receipts


In [8]:
# Load a dataset configuration by its registered name and build the dataset.
# NOTE(review): the name selected here is an RVL-CDIP "image_with_ocr" variant
# (document images plus OCR), not a plain image-classification dataset.
from atria_datasets import load_dataset_config

dataset_config = load_dataset_config("rvlcdip/image_with_ocr_1k")
dataset = dataset_config.build()



In [21]:
# Keep only the first sample of the training split for inspection.
# `x` stays None if iterating `dataset.train` yields nothing.
x = None
for sample in dataset.train:
    x = sample
    break

In [None]:
from typing import Optional, Union
import bs4

from atria_types._generic._bounding_box import BoundingBox
from atria_types._generic._doc_content import TextElement

import bs4
import networkx
import tqdm

class OCRGraphNode(BaseModel):
    """A single node of an OCR graph.

    Only ``id`` is mandatory; every other attribute is optional and falls
    back to ``None`` when it is not supplied.
    """

    # Node identifier; either an integer or a string is accepted.
    id: int | str
    # Text carried by the node, if any.
    word: str | None = None
    # OCR hierarchy level this node belongs to.
    level: OCRLevel | None = None
    # Bounding box of the node itself.
    bbox: BoundingBox | None = None
    # Second bounding box; presumably that of the enclosing segment -- TODO confirm.
    segment_level_bbox: BoundingBox | None = None
    # Recognition confidence.
    conf: float | None = None
    # Text angle.
    angle: float | None = None
    # Label attached to the node.
    label: Label | None = None


class OCRGraphLink(BaseModel):
    """A link between two OCR graph nodes, referenced by their ids."""

    # Id of the node the link starts from.
    source: int | str
    # Id of the node the link points to.
    target: int | str
    # Optional name of the relation expressed by the link.
    relation: str | None = None


class OCRGraph(BaseModel):
    """OCR content stored as a graph of nodes and links.

    The ``word_*`` properties expose per-word attributes as parallel lists:
    each one walks ``nodes`` in stored order and keeps only the nodes whose
    ``level`` is ``OCRLevel.WORD``. Entries are ``None`` where the node does
    not carry the attribute.
    """

    # NOTE(review): these field names look like the keys of networkx's
    # node-link format -- confirm against the producer of this data.
    directed: Optional[bool]
    multigraph: Optional[bool]
    graph: Optional[dict]
    nodes: List[OCRGraphNode]
    links: List[OCRGraphLink]

    def _word_nodes(self) -> List[OCRGraphNode]:
        """Return the word-level nodes in stored order."""
        return [node for node in self.nodes if node.level == OCRLevel.WORD]

    @property
    def words(self) -> List[Optional[str]]:
        """Text of every word-level node."""
        return [node.word for node in self._word_nodes()]

    @property
    def word_bboxes(self) -> List[Optional[BoundingBox]]:
        """Bounding box of every word-level node."""
        return [node.bbox for node in self._word_nodes()]

    @property
    def word_segment_level_bboxes(self) -> List[Optional[BoundingBox]]:
        """``segment_level_bbox`` of every word-level node."""
        return [node.segment_level_bbox for node in self._word_nodes()]

    @property
    def word_labels(self) -> List[Optional[Label]]:
        """Label of every word-level node."""
        return [node.label for node in self._word_nodes()]

    @property
    def word_confs(self) -> List[Optional[float]]:
        """Confidence of every word-level node."""
        return [node.conf for node in self._word_nodes()]

    @property
    def word_angles(self) -> List[Optional[float]]:
        """Angle of every word-level node."""
        return [node.angle for node in self._word_nodes()]

    @classmethod
    def from_word_level_content(
        cls,
        words: List[str],
        word_bboxes: Optional[List[BoundingBox]] = None,
        word_labels: Optional[List[Label]] = None,
        word_segment_level_bboxes: Optional[List[BoundingBox]] = None,
        word_confs: Optional[List[float]] = None,
        word_angles: Optional[List[float]] = None,
    ) -> "OCRGraph":
        """Build a directed word chain from parallel per-word lists.

        Word ``i`` becomes a ``OCRLevel.WORD`` node with id ``i``, and each
        word is linked to its successor with relation ``"next_word"``.

        Args:
            words: Word texts; determines the number of nodes.
            word_bboxes: Per-word bounding boxes, or ``None`` to leave unset.
            word_labels: Per-word labels, or ``None`` to leave unset.
            word_segment_level_bboxes: Per-word segment-level boxes, or ``None``.
            word_confs: Per-word confidences, or ``None``.
            word_angles: Per-word angles, or ``None``.

        Returns:
            An instance of ``cls`` (so subclasses get their own type back).

        Raises:
            IndexError: If a provided list is shorter than ``words``.
        """

        def item(values, index):
            # A missing attribute list means "not available" for every word.
            return values[index] if values is not None else None

        nodes = [
            OCRGraphNode(
                id=i,
                word=word,
                level=OCRLevel.WORD,
                bbox=item(word_bboxes, i),
                segment_level_bbox=item(word_segment_level_bboxes, i),
                conf=item(word_confs, i),
                angle=item(word_angles, i),
                label=item(word_labels, i),
            )
            for i, word in enumerate(words)
        ]

        links = [
            OCRGraphLink(source=i, target=i + 1, relation="next_word")
            for i in range(len(words) - 1)
        ]

        return cls(
            directed=True, multigraph=False, graph={}, nodes=nodes, links=links
        )

class HOCRGraphParser:
    """Turn an hOCR document into a node-link graph of its layout hierarchy.

    Every page, block, paragraph, line and word element becomes a node, and
    each element is connected to its parent by an edge with
    ``relation="child"``. Node ids are consecutive integers assigned in
    document (pre-order) traversal order, starting at 0.
    """

    # (CSS selector, OCR level name) for each layer, outermost first.
    _HIERARCHY = (
        ("div.ocr_page", "page"),
        ("div.ocr_carea", "block"),
        ("p.ocr_par", "paragraph"),
        ("span.ocr_line", "line"),
        ("span.ocrx_word", "word"),
    )

    def __init__(self, ocr: str):
        """Parse the raw hOCR markup and cache the page size.

        Args:
            ocr: hOCR document as a string.
        """
        self._soup = bs4.BeautifulSoup(ocr, features="xml")
        self._image_size = self._get_image_size()

    @staticmethod
    def _title_properties(title: str) -> dict[str, str]:
        """Split an hOCR ``title`` attribute into ``{property name: value}``.

        Properties are separated by ``;`` and each one is a name followed by
        its whitespace-separated value(s), e.g.
        ``"bbox 1 2 3 4; x_wconf 77"`` -> ``{"bbox": "1 2 3 4", "x_wconf": "77"}``.
        """
        props: dict[str, str] = {}
        for chunk in title.split(";"):
            parts = chunk.split(None, 1)
            if parts:
                props[parts[0]] = parts[1].strip() if len(parts) > 1 else ""
        return props

    def _get_image_size(self) -> tuple[int, int]:
        """Return ``(width, height)`` taken from the first page's ``bbox``.

        Falls back to ``(1, 1)`` when there is no page element or the page has
        no ``bbox`` property; coordinates are then left unscaled by
        ``_add_node``.
        """
        page = self._soup.find("div", {"class": "ocr_page"})
        if page is not None:
            bbox = self._title_properties(page.get("title", "")).get("bbox")
            if bbox:
                _, _, x2, y2 = map(int, bbox.split())
                return x2, y2
        return (1, 1)

    def _add_node(
        self,
        graph: networkx.DiGraph,
        node_id: int,
        tag: bs4.Tag,
        level: str,
        parent_id: int | str | None = None,
    ):
        """Add ``tag`` to ``graph`` as node ``node_id`` and link it to its parent.

        Args:
            graph: Graph being built.
            node_id: Id to assign to the new node.
            tag: hOCR element whose ``title`` holds bbox / confidence / angle.
            level: OCR level name; must be a valid ``OCRLevel`` value.
            parent_id: Id of the parent node, or ``None`` for a root node.

        Returns:
            ``node_id``, so callers can pass it on as the children's parent.
        """
        props = self._title_properties(tag.get("title", ""))

        # Bounding box, divided by the page width/height.
        # NOTE(review): HOCRProcessor.parse in this notebook passes
        # normalized=True for equally scaled boxes; here the flag is left at
        # BoundingBox's default -- confirm which is intended.
        bbox = None
        if "bbox" in props:
            x1, y1, x2, y2 = map(int, props["bbox"].split())
            w, h = self._image_size
            bbox = BoundingBox(value=[x1 / w, y1 / h, x2 / w, y2 / h])

        # Confidence: first token of the property, so a following property
        # ("x_wconf 77; ...") no longer leaks a trailing ';' into float().
        conf = float(props["x_wconf"].split()[0]) if "x_wconf" in props else None

        # Text angle defaults to 0.0 when the element does not state one.
        angle = float(props["textangle"]) if "textangle" in props else 0.0

        text = tag.get_text(strip=True)
        graph.add_node(
            node_id,
            **OCRGraphNode(
                id=node_id,
                # Only word nodes carry text; higher levels would repeat it.
                word=text if level == OCRLevel.WORD.value else None,
                level=OCRLevel(level),
                bbox=bbox,
                conf=conf,
                angle=angle,
            ).model_dump(),
        )

        # Compare against None explicitly: the first page has id 0, which is
        # falsy, and a truthiness test would drop all of its child edges.
        if parent_id is not None:
            graph.add_edge(parent_id, node_id, relation="child")

        return node_id

    def _walk(self, graph, root, depth, parent_id, ids, pbar) -> None:
        """Add all elements of hierarchy layer ``depth`` found under ``root``.

        Recurses into each added element for the next layer, so ids are
        handed out in pre-order.
        """
        if depth == len(self._HIERARCHY):
            return
        selector, level = self._HIERARCHY[depth]
        for tag in root.select(selector):
            node_id = self._add_node(
                graph, next(ids), tag, level, parent_id=parent_id
            )
            pbar.update(1)
            self._walk(graph, tag, depth + 1, node_id, ids, pbar)

    def parse_graph(self) -> dict:
        """
        Parse HOCR into a NetworkX directed graph using its predictable structure.
        Returns:
            dict: Graph data in node-link format.
        """
        import itertools

        from networkx.readwrite import json_graph

        graph = networkx.DiGraph()
        ids = itertools.count()

        # The context manager closes the progress bar even if parsing fails.
        with tqdm.tqdm() as pbar:
            self._walk(graph, self._soup, 0, None, ids, pbar)

        return json_graph.node_link_data(graph)
class HOCRProcessor:
    """Extract word-level text elements from an hOCR document."""

    @staticmethod
    def _title_properties(title: str) -> dict[str, str]:
        """Split an hOCR ``title`` attribute into ``{property name: value}``.

        Properties are separated by ``;`` and each one is a name followed by
        its whitespace-separated value(s), e.g.
        ``"bbox 1 2 3 4; x_wconf 77"`` -> ``{"bbox": "1 2 3 4", "x_wconf": "77"}``.
        """
        props: dict[str, str] = {}
        for chunk in title.split(";"):
            parts = chunk.split(None, 1)
            if parts:
                props[parts[0]] = parts[1].strip() if len(parts) > 1 else ""
        return props

    @staticmethod
    def parse(raw_ocr: str) -> list[TextElement]:
        """Parse hOCR markup into one ``TextElement`` per non-empty word.

        Word boxes are divided by the width/height of the first page's
        ``bbox`` and flagged ``normalized=True``. The text angle is read from
        the ``textangle`` property of the word's parent element and defaults
        to 0.0.

        Args:
            raw_ocr: hOCR document as a string.

        Returns:
            Text elements in document order; words whose text is empty after
            stripping are skipped.

        Raises:
            IndexError: If the document has no ``ocr_page`` element or the
                page title has no ``bbox`` property.
            ValueError: If a word title lacks ``bbox`` or ``x_wconf``, or a
                property value is not numeric.
        """
        soup = bs4.BeautifulSoup(raw_ocr, features="xml")
        props_of = HOCRProcessor._title_properties

        # Image size = right/bottom edge of the first page's bounding box.
        page = soup.find("div", {"class": "ocr_page"})
        if page is None:
            raise IndexError("hOCR document contains no ocr_page element")
        page_bbox = props_of(page["title"]).get("bbox")
        if page_bbox is None:
            raise IndexError("ocr_page title has no bbox property")
        _, _, w, h = map(int, page_bbox.split())

        text_elements: list[TextElement] = []
        for word in soup.find_all("span", {"class": "ocrx_word"}):
            text = word.text.strip()
            if text == "":
                continue

            title = word["title"]
            props = props_of(title)
            if "bbox" not in props or "x_wconf" not in props:
                raise ValueError(
                    f"ocrx_word title lacks bbox or x_wconf: {title!r}"
                )
            conf = float(props["x_wconf"].split()[0])
            x1, y1, x2, y2 = map(int, props["bbox"].split())

            # The angle lives on the enclosing line. Read the whole value:
            # a fixed two-character slice would turn "180" into 18.
            line_props = props_of(word.parent.get("title", ""))
            textangle = (
                float(line_props["textangle"]) if "textangle" in line_props else 0.0
            )

            text_elements.append(
                TextElement(
                    text=text,
                    bbox=BoundingBox(
                        value=[x1 / w, y1 / h, x2 / w, y2 / h], normalized=True
                    ),
                    conf=conf,
                    angle=textangle,
                )
            )

        return text_elements
    
# Run the parser on the OCR content of the sample held in `x` and display the
# resulting list of text elements as the cell output.
HOCRProcessor.parse(x.ocr.content)

<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html lang="en" xml:lang="en" xmlns="http://www.w3.org/1999/xhtml">
<head>
<title/>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type"/>
<meta content="tesseract 5.0.1-9-g31a968" name="ocr-system"/>
<meta content="ocr_page ocr_carea ocr_par ocr_line ocrx_word ocrp_wconf" name="ocr-capabilities"/>
</head>
<body>
<div class="ocr_page" id="page_1" title='image "/tmp/tess_c8ql3now.JPEG"; bbox 0 0 754 1000; ppageno 0; scan_res 70 70'>
<div class="ocr_carea" id="block_1_1" title="bbox 112 315 151 351">
<p class="ocr_par" id="par_1_1" lang="eng" title="bbox 112 315 151 351">
<span class="ocr_line" id="line_1_1" title="bbox 112 315 134 323; baseline 0 0; x_size 20; x_descenders 5; x_ascenders 5">
<span class="ocrx_word" id="word_1_1" title="bbox 112 315 134 323; x_wconf 77">To:</span>
</span>
<span class="ocr_line" id="l

  pages = soup.findAll("div", {"class": "ocr_page"})
  ocr_words = soup.findAll("span", {"class": "ocrx_word"})


[TextElement(
     text='To:',
     bbox=BoundingBox(
         value=[0.14854111405835543, 0.315, 0.17771883289124668, 0.323],
         mode=<BoundingBoxMode.XYXY: 'xyxy'>,
         normalized=True
     ),
     segment_bbox=None,
     conf=77.0,
     angle=0.0
 ),
 TextElement(
     text='FROM:',
     bbox=BoundingBox(
         value=[0.14986737400530503, 0.342, 0.2002652519893899, 0.351],
         mode=<BoundingBoxMode.XYXY: 'xyxy'>,
         normalized=True
     ),
     segment_bbox=None,
     conf=87.0,
     angle=0.0
 ),
 TextElement(
     text='SUBJECT:',
     bbox=BoundingBox(
         value=[0.14854111405835543, 0.369, 0.23209549071618038, 0.378],
         mode=<BoundingBoxMode.XYXY: 'xyxy'>,
         normalized=True
     ),
     segment_bbox=None,
     conf=91.0,
     angle=0.0
 ),
 TextElement(
     text='According',
     bbox=BoundingBox(
         value=[0.14721485411140584, 0.423, 0.2453580901856764, 0.435],
         mode=<BoundingBoxMode.XYXY: 'xyxy'>,
         normalized=T