In [None]:
# Colab-only setup: mount Google Drive so the '/content/drive/...' paths used
# further down in this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install python-dotenv

In [None]:
!pip install -U attrdict3

# Data processing

In [None]:
from argparse import ArgumentParser
import os
from attrdict import AttrDict
import yaml

def create_folder(folder_path : str) -> None:
    """Create a folder, including missing parent folders, if it does not exist.

    Args:
        folder_path (str): path of the folder to create
    """
    # Keep the original "anything already at this path -> do nothing" contract,
    # but build intermediate folders too and tolerate another process creating
    # the folder between the check and the call.
    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)

def get_config(name : str) -> AttrDict:
    """Load `<name>.yaml` from the CONFIGS folder.

    Args:
        name (str): yaml file name without extension

    Returns:
        AttrDict: parsed configuration wrapped for attribute-style access
    """
    config_path = CONFIGS / f'{name}.yaml'
    with open(config_path) as stream:
        return AttrDict(yaml.safe_load(stream))

def project_tree() -> None:
    """Create every folder of the project tree that does not exist yet.

    The folders are created in the order listed below, via `create_folder`.
    """
    project_folders = (
        DATA,
        OUTPUTS,
        RUNS,
        RESULTS,
        TRAIN_SAMPLES,
        TEST_SAMPLES,
        CHECKPOINTS,
    )
    for folder in project_folders:
        create_folder(folder)

def set_preprocessing(args: ArgumentParser) -> None:
    """Write `preprocessing.yaml`: the base config overridden by `args`.

    `base.yaml` is read from CONFIGS, the options listed below are replaced by the
    attributes of the same name found on `args`, and the result is dumped to
    `preprocessing.yaml` in the same folder.

    Args:
        args (ArgumentParser): object exposing the options as attributes
            (presumably the parsed command-line namespace - TODO confirm)
    """
    with open(CONFIGS / 'base.yaml') as base_file:
        settings = dict(yaml.safe_load(base_file))

    # config section -> options copied from `args` into that section
    overrides = (
        ('FEATURES', ('add_geom', 'add_embs', 'add_hist', 'add_visual',
                      'add_eweights', 'num_polar_bins')),
        ('LOADER', ('src_data',)),
        ('GRAPHS', ('data_type', 'edge_type', 'node_granularity')),
    )
    for section, options in overrides:
        for option in options:
            settings[section][option] = getattr(args, option)

    with open(CONFIGS / 'preprocessing.yaml', 'w') as out_file:
        yaml.dump(settings, out_file)

In [None]:
#!unzip /content/drive/MyDrive/doc2graph-master/tutorial/dataset.zip -d /content/drive/MyDrive/doc2graph-master/tutorial/funsd

In [None]:
from PIL import Image, ImageDraw


# Pick the sample document to run the tutorial on and display it.
# The path points into the mounted Google Drive copy of FUNSD.
image_name = '/content/drive/MyDrive/doc2graph-master/tutorial/funsd/dataset/training_data/images/0000971160.png' #! change this to see different outputs from FUNSD, or pass your own image!
# another sample that can be tried: 82491256.png
image_path = str(image_name)
# Force 3-channel RGB so later drawing calls can use colour names.
image = Image.open(image_path).convert('RGB')
image

In [None]:
!pip install easyocr

In [None]:
!pip install --upgrade pytesseract

In [None]:
!sudo apt install tesseract-ocr

In [None]:
import torch
import torchvision
import numpy as np
from scipy.optimize import linprog
import os
from PIL import ImageDraw, Image
import json
import pytesseract
from pytesseract import Output


def scale_back(r, w, h):
    """Scale a 4-value box by the image size and truncate to integers.

    Values at positions 0 and 2 are multiplied by `w`, positions 1 and 3 by `h`.
    """
    x_first, y_first, x_second, y_second = r[0], r[1], r[2], r[3]
    return [int(x_first * w), int(y_first * h), int(x_second * w), int(y_second * h)]


def center(r):
    """Return the (x, y) midpoint of a box given as [x1, y1, x2, y2]."""
    mid_x = (r[0] + r[2]) / 2
    mid_y = (r[1] + r[3]) / 2
    return (mid_x, mid_y)


def isIn(c, r):
    """Tell whether point `c` = (x, y) lies inside box `r` = [x1, y1, x2, y2].

    Points on the border count as inside.
    """
    outside = c[0] < r[0] or c[0] > r[2] or c[1] < r[1] or c[1] > r[3]
    return not outside


def match_pred_w_gt(bbox_preds: torch.Tensor, bbox_gts: torch.Tensor, links_pair: list):
    """Match predicted boxes to ground-truth boxes and report what got lost.

    The IoU of every (prediction, ground truth) pair is computed, then a linear
    program selects assignments that maximise the summed IoU while using each
    prediction and each ground truth at most once. A pair is accepted only when
    its IoU is above 0.5 and its assignment value is above 0.9.

    Args:
        bbox_preds (torch.Tensor): predicted boxes, one 4-value row per box
        bbox_gts (torch.Tensor): ground-truth boxes, one 4-value row per box
        links_pair (list): ground-truth links as pairs of ground-truth indices

    Returns:
        dict with keys:
            pred2gt / gt2pred: accepted matches in both directions
            false_positive: prediction indices without a match
            false_negative: ground-truth indices without a match
            n_link_fn: half the number of links touching an unmatched ground truth
            link_loss: fraction of links touching an unmatched ground truth
            entity_loss: fraction of predictions left unmatched (0 if there are none)
    """
    # torch.tensor([]) is 1-D; give empty inputs the (0, 4) layout box_iou needs.
    if bbox_preds.numel() == 0:
        bbox_preds = bbox_preds.reshape(0, 4)
    if bbox_gts.numel() == 0:
        bbox_gts = bbox_gts.reshape(0, 4)

    bbox_iou = torchvision.ops.box_iou(boxes1=bbox_preds, boxes2=bbox_gts)
    bbox_iou = bbox_iou.numpy()
    n_preds, n_gts = bbox_iou.shape

    opt_assignaments = {}
    if n_preds > 0 and n_gts > 0:
        # Constraint rows: first one "sum <= 1" row per prediction (its slice of
        # the flattened IoU matrix), then one per ground truth (a strided column).
        A_ub = np.zeros(shape=(n_preds + n_gts, n_preds * n_gts))
        for r in range(n_preds):
            st = r * n_gts
            A_ub[r, st:st + n_gts] = 1
        for j in range(n_gts):
            A_ub[n_preds + j, j::n_gts] = 1
        b_ub = np.ones(shape=A_ub.shape[0])

        # linprog minimises, hence the negated IoU as cost.
        solution = linprog(
            c=-bbox_iou.reshape(-1), A_ub=A_ub, b_ub=b_ub, bounds=(0, 1), method="highs-ds")
        if solution.success and solution.x is not None:
            assignaments_score = solution.x.reshape(bbox_iou.shape)
            assignaments_ids = assignaments_score.argmax(axis=1)
            for idx in range(n_preds):
                best = assignaments_ids[idx]
                if (bbox_iou[idx, best] > 0.5) and (assignaments_score[idx, best] > 0.9):
                    opt_assignaments[idx] = best
        else:
            # No usable solution vector: report it and continue with no matches
            # instead of crashing on `None.reshape`.
            print("Optimization FAILED")

    # unmatched predictions
    false_positive = [idx for idx in range(n_preds) if idx not in opt_assignaments]
    # unmatched gts
    matched_gts = set(opt_assignaments.values())
    false_negative = [idx for idx in range(n_gts) if idx not in matched_gts]

    gt2pred = {v: k for k, v in opt_assignaments.items()}
    missing_gts = set(false_negative)
    link_false_neg = [link for link in links_pair
                      if link[0] in missing_gts or link[1] in missing_gts]

    rate = len(link_false_neg) / len(links_pair) if len(links_pair) != 0 else 0

    # Every prediction is either matched or a false positive, so this is the
    # prediction count; guard the "no predictions at all" case.
    n_candidates = len(false_positive) + len(opt_assignaments)
    entity_loss = len(false_positive) / n_candidates if n_candidates else 0

    return {"pred2gt": opt_assignaments, "gt2pred": gt2pred, "false_positive": false_positive,
            "false_negative": false_negative, "n_link_fn": int(len(link_false_neg) / 2),
            "link_loss": rate, "entity_loss": entity_loss}


def get_objects(path, mode):
    """Placeholder for object detection on a document; currently returns None.

    Both arguments are accepted but not used yet.
    """
    # TODO given a document, apply OCR or Yolo to detect either words or entities.
    return None


def load_predictions(path_preds, path_gts, path_images, debug=False):
    """Load detector predictions, OCR words and ground truth for a folder of images.

    For every file in `path_images` the function reads:
      - the image (its size and the Tesseract word boxes with confidence > 50),
      - `<stem>.txt` in `path_preds`: one box per line, the first column is skipped
        and the next four are taken as center-x, center-y, width and height
        relative to the image size,
      - `<stem>.json` in `path_gts`: a `form` list whose items provide `box`, `id`,
        `label` and `linking`.
    `<stem>` is the image file name up to its first dot.

    Predictions are matched to ground truth with `match_pred_w_gt`. Matched
    predictions inherit the ground-truth label, unmatched ones get 'other'; links
    are kept only when both ends are matched. The text of a prediction is the
    concatenation of the OCR words whose center falls inside it. The average link
    and entity loss over all images is printed.

    Args:
        path_preds: folder with the prediction `.txt` files
        path_gts: folder with the ground-truth `.json` files
        path_images: folder with the images
        debug (bool): also print details for the first image and save an overlay
            of boxes/links to 'prova.png'

    Returns:
        tuple: (all_paths, all_preds, all_links, all_labels, all_texts), each a list
        with one entry per image
    """
    boxs_preds = []
    boxs_gts = []
    links_gts = []
    labels_gts = []
    texts_ocr = []
    all_paths = []

    for img in os.listdir(path_images):
        img_path = os.path.join(path_images, img)
        all_paths.append(img_path)

        # Open the page once, use it for both size and OCR, and close the handle.
        with Image.open(img_path) as page:
            w, h = page.size
            texts = pytesseract.image_to_data(page, output_type=Output.DICT)

        tp = []
        for t in range(len(texts['level'])):
            # `conf` may arrive as int, float or numeric string; int(float(.))
            # accepts all three and keeps the original integer threshold.
            if int(float(texts['conf'][t])) > 50 and texts['text'][t].strip():
                b = [texts['left'][t], texts['top'][t], texts['left'][t] +
                     texts['width'][t], texts['top'][t] + texts['height'][t]]
                tp.append([b, texts['text'][t]])
        texts_ocr.append(tp)

        stem = img.split(".")[0]
        with open(os.path.join(path_preds, stem + '.txt'), 'r') as preds:
            boxs = list()
            for line in preds:
                # split() copes with a missing final newline, repeated spaces and
                # blank lines (slicing off the last character did not).
                fields = line.split()
                if not fields:
                    continue
                scaled = scale_back([float(c) for c in fields[1:]], w, h)
                # center/size -> corner coordinates
                sw, sh = scaled[2] / 2, scaled[3] / 2
                boxs.append([scaled[0] - sw, scaled[1] - sh,
                             scaled[0] + sw, scaled[1] + sh])
            boxs_preds.append(boxs)

        with open(os.path.join(path_gts, stem + '.json'), 'r') as f:
            form = json.load(f)['form']
        boxs = list()
        pair_labels = []
        ids = []
        labels = []
        for elem in form:
            boxs.append([float(e) for e in elem['box']])
            ids.append(elem['id'])
            labels.append(elem['label'])
            pair_labels.extend(elem['linking'])

        # links are expressed with entity ids; turn them into list positions
        for p, pair in enumerate(pair_labels):
            pair_labels[p] = [ids.index(pair[0]), ids.index(pair[1])]

        boxs_gts.append(boxs)
        links_gts.append(pair_labels)
        labels_gts.append(labels)

    all_links = []
    all_preds = []
    all_labels = []
    all_texts = []
    dropped_links = 0
    dropped_entity = 0
    n_docs = len(boxs_preds)

    for p in range(n_docs):
        d = match_pred_w_gt(torch.tensor(
            boxs_preds[p]), torch.tensor(boxs_gts[p]), links_gts[p])
        dropped_links += d['link_loss']
        dropped_entity += d['entity_loss']

        # keep only links whose both ends were matched, re-indexed on predictions
        links = list()
        for link in links_gts[p]:
            if link[0] in d['false_negative'] or link[1] in d['false_negative']:
                continue
            links.append([d['gt2pred'][link[0]], d['gt2pred'][link[1]]])
        all_links.append(links)

        preds = []
        labels = []
        texts = []
        for b, box in enumerate(boxs_preds[p]):
            preds.append(box)
            if b in d['false_positive']:
                labels.append('other')
            else:
                labels.append(labels_gts[p][d['pred2gt'][b]])

            text = ''
            for tocr in texts_ocr[p]:
                if isIn(center(tocr[0]), box):
                    text += tocr[1] + ' '
            texts.append(text)

        all_preds.append(preds)
        all_labels.append(labels)
        all_texts.append(texts)

    # An empty image folder has nothing to average (and nothing to debug).
    if n_docs:
        print(dropped_links / n_docs, dropped_entity / n_docs)

    if debug and n_docs:
        # random.seed(35)
        # rand_idx = random.randint(0, len(os.listdir(path_images)))
        print(all_texts[0])
        rand_idx = 0
        # all_paths is index-aligned with the per-image lists built above
        img = Image.open(all_paths[rand_idx]).convert('RGB')
        draw = ImageDraw.Draw(img)

        rand_boxs_preds = boxs_preds[rand_idx]
        rand_boxs_gts = boxs_gts[rand_idx]

        for box in rand_boxs_gts:
            draw.rectangle(box, outline='blue', width=3)
        for box in rand_boxs_preds:
            draw.rectangle(box, outline='red', width=3)

        d = match_pred_w_gt(torch.tensor(rand_boxs_preds),
                            torch.tensor(rand_boxs_gts), links_gts[rand_idx])
        print(d)
        for idx in d['pred2gt'].keys():
            draw.rectangle(rand_boxs_preds[idx], outline='green', width=3)

        link_true_pos = list()
        link_false_neg = list()
        for link in links_gts[rand_idx]:
            if link[0] in d['false_negative'] or link[1] in d['false_negative']:
                link_false_neg.append(link)
                start = rand_boxs_gts[link[0]]
                end = rand_boxs_gts[link[1]]
                draw.line((center(start), center(end)), fill='red', width=3)
            else:
                link_true_pos.append(link)
                start = rand_boxs_preds[d['gt2pred'][link[0]]]
                end = rand_boxs_preds[d['gt2pred'][link[1]]]
                draw.line((center(start), center(end)), fill='green', width=3)

        precision = 0
        recall = 0
        for idx, gt in enumerate(boxs_gts):
            # each image is evaluated against its own links (was links_gts[rand_idx])
            d = match_pred_w_gt(torch.tensor(
                boxs_preds[idx]), torch.tensor(gt), links_gts[idx])
            bbox_true_positive = len(d["pred2gt"])
            n_pred = bbox_true_positive + len(d["false_positive"])
            n_gt = bbox_true_positive + len(d["false_negative"])
            precision += bbox_true_positive / n_pred if n_pred else 0
            recall += bbox_true_positive / n_gt if n_gt else 0

        precision = precision / len(boxs_gts)
        recall = recall / len(boxs_gts)
        if precision + recall:
            f1 = (2 * precision * recall) / (precision + recall)
        else:
            f1 = 0
        # print(f1, precision, recall)

        img.save('prova.png')

    return all_paths, all_preds, all_links, all_labels, all_texts

# if __name__ == "__main__":
#     path_preds = '/content/drive/MyDrive/doc2graph-master/tutorial/funsd'
#     path_images = '/content/drive/MyDrive/doc2graph-master/tutorial/funsd/dataset/training_data/images'
#     path_gts = '/content/drive/MyDrive/doc2graph-master/tutorial/funsd/dataset/training_data/annotations'
#     load_predictions(path_preds, path_gts, path_images, debug=True)

In [None]:
import easyocr

# Module-level EasyOCR reader (English); `apply_ocr` below uses this instance.
reader = easyocr.Reader(['en']) #! support multilingual!

def apply_ocr(path):
    """Run the module-level EasyOCR `reader` on an image, in paragraph mode.

    Args:
        path: image passed straight to `reader.readtext`

    Returns:
        tuple: (boxs, texts) - one [x1, y1, x2, y2] integer box and one string per
        detection. The box is built from the first and third point that the reader
        reports for the detection (presumably opposite corners - TODO confirm).
    """
    detections = reader.readtext(path, paragraph=True)

    # det[0] holds the detection's points, det[1] its text
    boxs = [
        [int(det[0][0][0]), int(det[0][0][1]), int(det[0][2][0]), int(det[0][2][1])]
        for det in detections
    ]
    texts = [det[1] for det in detections]

    return boxs, texts

def draw_results(img, boxs, links):
    """Draw boxes (blue) and, when given, links between them (violet) onto `img`.

    Args:
        img: PIL image, modified in place
        boxs: list of [x1, y1, x2, y2] boxes
        links: mapping with parallel 'src' and 'dst' sequences of box indices;
            anything falsy means "no links"
    """
    canvas = ImageDraw.Draw(img)

    for rect in boxs:
        canvas.rectangle(rect, outline='blue', width=3)

    if not links:
        return

    # each link is drawn between the centers of its two boxes
    for pos in range(len(links['src'])):
        start = center(boxs[links['src'][pos]])
        stop = center(boxs[links['dst'][pos]])
        canvas.line((start, stop), fill='violet', width=3)

In [None]:
#! get text boxes and contents
# OCR the sample image, then draw the detected boxes on it (no links yet, hence
# the empty list) and display the annotated image.
boxs, texts = apply_ocr(image_path)
draw_results(image, boxs, [])
image

In [None]:
from math import sqrt
from typing import Tuple
import cv2
import numpy as np
import torch
import math

def polar(rect_src : list, rect_dst : list) -> Tuple[int, int]:
    """Compute distance and angle from src to dst bounding boxes (polar coordinates with src as the center).

    The distance is the gap between the two rectangles: 0 when they overlap or
    touch, the horizontal/vertical gap when one lies beside/above/below the other,
    and the corner-to-corner distance otherwise. The angle is measured between the
    two rectangle centers.

    Args:
        rect_src (list) : source rectangle coordinates [x1, y1, x2, y2]
        rect_dst (list) : destination rectangle coordinates [x1, y1, x2, y2]

    Returns:
        tuple (ints): distance (truncated to int) and angle in degrees, in [0, 360)
    """
    # Gap along each axis; a non-positive value means the projections overlap or
    # touch, which counts as "no gap". Assumes x1 <= x2 and y1 <= y2.
    gap_x = max(rect_dst[0] - rect_src[2], rect_src[0] - rect_dst[2], 0)
    gap_y = max(rect_dst[1] - rect_src[3], rect_src[1] - rect_dst[3], 0)
    # One formula covers the overlap, side and corner cases, and always yields an
    # int as the signature promises (side cases used to return the raw gap).
    distance = int(sqrt(gap_x ** 2 + gap_y ** 2))

    # angle of the dst center as seen from the src center
    src_cx, src_cy = (rect_src[2] + rect_src[0]) / 2, (rect_src[3] + rect_src[1]) / 2
    dst_cx, dst_cy = (rect_dst[2] + rect_dst[0]) / 2, (rect_dst[3] + rect_dst[1]) / 2
    angle = int(math.degrees(math.atan2(dst_cy - src_cy, dst_cx - src_cx)) % 360)

    return distance, angle

def transform_image(img_path : str, scale_image=1.0) -> torch.Tensor:
    """ Transform image to torch.Tensor

    The image is read with OpenCV, resized by `scale_image`, converted to
    grayscale and mapped with `1.0 - pixel / 128.0`.

    Args:
        img_path (str) : where the image is stored
        scale_image (float) : how much scale the image

    Returns:
        torch.Tensor: float32 tensor with two leading singleton dimensions followed
        by the (height, width) of the resized image

    Raises:
        OSError: if OpenCV cannot read the image
    """

    np_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    # imread reports failure by returning None instead of raising
    if np_img is None:
        raise OSError(f"transform_image: could not read image '{img_path}'")

    height, width = np_img.shape[:2]
    new_size = (int(width * scale_image), int(height * scale_image))  # cv2 wants (w, h)
    np_img = cv2.resize(np_img, new_size)
    img = cv2.cvtColor(np_img, cv2.COLOR_BGR2GRAY)
    img = img[None,None,:,:]
    img = img.astype(np.float32)
    img = torch.from_numpy(img)
    img = 1.0 - img / 128.0

    return img

def get_histogram(contents : list) -> list:
    """Create a character-class histogram for each text.

    Args:
        contents (list): texts to describe (spaces are ignored)

    Returns:
        list of [x, y, z, w] - one 4-value list per text, summing up to 1, where:
            - x is the share of letters inside the text
            - y is the share of digits inside the text
            - z is the share of other symbols, i.e. @, #, .., inside the text
            - w is 1.0 when the text holds no symbol at all, else 0.0
    """

    histograms = []

    for token in contents:
        compact = token.replace(" ", "")
        total = len(compact)

        # counts: letters, digits, everything else
        counts = [0, 0, 0]
        for char in compact:
            if char.isalpha():
                counts[0] += 1
            elif char.isdigit():
                counts[1] += 1
            else:
                counts[2] += 1

        shares = [0.0000, 0.0000, 0.0000, 0.0000]
        if total != 0:
            for slot in range(3):
                shares[slot] = counts[slot] / total

            # push any rounding residue onto the largest share so the sum is 1
            mass = sum(shares)
            if mass != 1.0:
                peak = max(shares)
                shares[shares.index(peak)] = peak + (1.0 - mass)

        # nothing recognised (or empty text): put the whole mass in the last slot
        if shares[0:3] == [0.0, 0.0, 0.0]:
            shares[3] = 1.0

        histograms.append(shares)

    return histograms

def to_bin(dist :int, angle : int, b=8) -> torch.Tensor:
    """ Discretize the space into equal "bins" and binary-encode the bin indices.

    Args:
        dist (int): distances in pixels, one per edge (iterated over), as given by "polar()"
        angle (int): angles between 0 and 360, one per edge (iterated over), as given by "polar()"
        b (int): number of bins, MUST be power of 2

    Returns:
        torch.Tensor: float32 tensor with one row per edge: the bits of the distance
        bin followed by the bits of the angle bin. Empty input gives a tensor with 0 rows.

    """
    def isPowerOfTwo(x):
        return (x and (not(x & (x - 1))) )

    assert isPowerOfTwo(b)

    # Every code is left-padded with zeros until it is at least sqrt(b) bits long.
    n_bits = math.ceil(sqrt(b))

    def encode(bin_index):
        bits = [int(x) for x in '{0:0b}'.format(bin_index)]
        return [0] * (n_bits - len(bits)) + bits

    def as_tensor(rows):
        codes = torch.tensor(rows, dtype=torch.float32)
        # torch.tensor([]) is 1-D; keep the column layout so torch.cat works
        return codes if rows else codes.reshape(0, n_bits)

    # dist: bins are equal slices of [0, max(dist)]
    m = max(dist) / b if len(dist) else 0
    new_dist = []
    for d in dist:
        # m == 0 means every distance is 0: they all fall in the first bin
        # (dividing by it used to raise ZeroDivisionError)
        index = int(d / m) if m else 0
        if index >= b: index = b - 1
        new_dist.append(encode(index))

    # angle: bins are 360 / b degrees wide, shifted by half a bin
    amplitude = 360 / b
    new_angle = []
    for a in angle:
        index = int((a - amplitude / 2) / amplitude)
        new_angle.append(encode(index))

    return torch.cat([as_tensor(new_dist), as_tensor(new_angle)], dim=1)

In [None]:
!pip install dgl

In [None]:
import json
import os
from PIL import Image, ImageDraw
from typing import Tuple
import torch
import dgl
import random
import numpy as np
from tqdm import tqdm
import xml.etree.ElementTree as ET
import easyocr


class GraphBuilder():

    def __init__(self):
        self.cfg_preprocessing = get_config('base')
        self.edge_type = self.cfg_preprocessing.GRAPHS.edge_type
        self.data_type = self.cfg_preprocessing.GRAPHS.data_type
        self.node_granularity = self.cfg_preprocessing.GRAPHS.node_granularity
        random.seed = 42
        return

    def get_graph(self, src_path : str, src_data : str) -> Tuple[list, list, list, list]:
        """ Given the source, it returns a graph

        Args:
            src_path (str) : path to source data
            src_data (str) : either FUNSD, PAU or CUSTOM

        Returns:
            tuple (lists) : graphs, nodes and edge labels, features
        """

        if src_data == 'FUNSD':
            return self.__fromFUNSD(src_path)
        elif src_data == 'PAU':
            return self.__fromPAU(src_path)
        elif src_data == 'CUSTOM':
            if self.data_type == 'img':
                return self.__fromIMG(src_path)
            elif self.data_type == 'pdf':
                return self.__fromPDF()
            else:
                raise Exception('GraphBuilder exception: data type invalid. Choose from ["img", "pdf"]')
        else:
            raise Exception('GraphBuilder exception: source data invalid. Choose from ["FUNSD", "PAU", "CUSTOM"]')

    def balance_edges(self, g : dgl.DGLGraph, cls=None ) -> dgl.DGLGraph:
        """ if cls (class) is not None, but an integer instead, balance that class to be equal to the sum of the other classes

        Args:
            g (DGLGraph) : a DGL graph
            cls (int) : class number, if any

        Returns:
            g (DGLGraph) : the new balanced graph
        """

        edge_targets = g.edata['label']
        u, v = g.all_edges(form='uv')
        edges_list = list()
        for e in zip(u.tolist(), v.tolist()):
            edges_list.append([e[0], e[1]])

        if type(cls) is int:
            to_remove = (edge_targets == cls)
            indices_to_remove = to_remove.nonzero().flatten().tolist()

            for _ in range(int((edge_targets != cls).sum()/2)):
                indeces_to_save = [random.choice(indices_to_remove)]
                edge = edges_list[indeces_to_save[0]]

                for index in sorted(indeces_to_save, reverse=True):
                    del indices_to_remove[indices_to_remove.index(index)]

            indices_to_remove = torch.flatten(torch.tensor(indices_to_remove, dtype=torch.int32))
            g = dgl.remove_edges(g, indices_to_remove)
            return g

        else:
            raise Exception("Select a class to balance (an integer ranging from 0 to num_edge_classes).")

    def get_info(self):
        """ returns graph information
        """
        print(f"-> edge type: {self.edge_type}")

    def fully_connected(self, ids : list) -> Tuple[list, list]:
        """ create fully connected graph

        Args:
            ids (list) : list of node indices

        Returns:
            u, v (lists) : lists of indices
        """
        u, v = list(), list()
        for id in ids:
            u.extend([id for i in range(len(ids)) if i != id])
            v.extend([i for i in range(len(ids)) if i != id])
        return u, v

    def knn_connection(self, size : tuple, bboxs : list, k = 10) -> Tuple[list, list]:
        """ Given a list of bounding boxes, find for each of them their k nearest ones.

        Args:
            size (tuple) : width and height of the image
            bboxs (list) : list of bounding box coordinates
            k (int) : k of the knn algorithm

        Returns:
            u, v (lists) : lists of indices
        """

        edges = []
        width, height = size[0], size[1]

        # creating projections
        vertical_projections = [[] for i in range(width)]
        horizontal_projections = [[] for i in range(height)]
        for node_index, bbox in enumerate(bboxs):
            for hp in range(bbox[0], bbox[2]):
                if hp >= width: hp = width - 1
                vertical_projections[hp].append(node_index)
            for vp in range(bbox[1], bbox[3]):
                if vp >= height: vp = height - 1
                horizontal_projections[vp].append(node_index)

        def bound(a, ori=''):
            if a < 0 : return 0
            elif ori == 'h' and a > height: return height
            elif ori == 'w' and a > width: return width
            else: return a

        for node_index, node_bbox in enumerate(bboxs):
            neighbors = [] # collect list of neighbors
            window_multiplier = 2 # how much to look around bbox
            wider = (node_bbox[2] - node_bbox[0]) > (node_bbox[3] - node_bbox[1]) # if bbox wider than taller

            ### finding neighbors ###
            while(len(neighbors) < k and window_multiplier < 100): # keep enlarging the window until at least k bboxs are found or window too big
                vertical_bboxs = []
                horizontal_bboxs = []
                neighbors = []

                if wider:
                    h_offset = int((node_bbox[2] - node_bbox[0]) * window_multiplier/4)
                    v_offset = int((node_bbox[3] - node_bbox[1]) * window_multiplier)
                else:
                    h_offset = int((node_bbox[2] - node_bbox[0]) * window_multiplier)
                    v_offset = int((node_bbox[3] - node_bbox[1]) * window_multiplier/4)

                window = [bound(node_bbox[0] - h_offset),
                        bound(node_bbox[1] - v_offset),
                        bound(node_bbox[2] + h_offset, 'w'),
                        bound(node_bbox[3] + v_offset, 'h')]

                [vertical_bboxs.extend(d) for d in vertical_projections[window[0]:window[2]]]
                [horizontal_bboxs.extend(d) for d in horizontal_projections[window[1]:window[3]]]

                for v in set(vertical_bboxs):
                    for h in set(horizontal_bboxs):
                        if v == h: neighbors.append(v)

                window_multiplier += 1 # enlarge the window

            ### finding k nearest neighbors ###
            neighbors = list(set(neighbors))
            if node_index in neighbors:
                neighbors.remove(node_index)
            neighbors_distances = [polar(node_bbox, bboxs[n])[0] for n in neighbors]
            for sd_num, sd_idx in enumerate(np.argsort(neighbors_distances)):
                if sd_num < k:
                    if [node_index, neighbors[sd_idx]] not in edges and [neighbors[sd_idx], node_index] not in edges:
                        edges.append([neighbors[sd_idx], node_index])
                        edges.append([node_index, neighbors[sd_idx]])
                else: break

        return [e[0] for e in edges], [e[1] for e in edges]

    def __fromIMG(self, paths : list):
        """ Build one graph per image: EasyOCR paragraphs are the nodes.

        For every image the detected boxes are also drawn in red and the result is
        saved to 'prova.png' (overwritten at each image).

        Args:
            paths (list) : image paths

        Returns:
            tuple : graphs, node labels and edge labels (both left empty), and a
            features dict with 'paths', 'texts' and 'boxs'
        """

        graphs, node_labels, edge_labels = list(), list(), list()
        features = {'paths': paths, 'texts': [], 'boxs': []}

        # Building a Reader is expensive, so one instance serves all the images;
        # it is created on first use so an empty `paths` still builds none.
        reader = None

        for path in paths:
            if reader is None:
                reader = easyocr.Reader(['en']) #! TODO: in the future, handle multilanguage!
            result = reader.readtext(path, paragraph=True)
            # open once, keep the RGB copy, release the file handle
            with Image.open(path) as source:
                img = source.convert('RGB')
            draw = ImageDraw.Draw(img)
            boxs, texts = list(), list()

            for r in result:
                box = [int(r[0][0][0]), int(r[0][0][1]), int(r[0][2][0]), int(r[0][2][1])]
                draw.rectangle(box, outline='red', width=3)
                boxs.append(box)
                texts.append(r[1])

            features['boxs'].append(boxs)
            features['texts'].append(texts)
            img.save('prova.png')

            if self.edge_type == 'fully':
                u, v = self.fully_connected(range(len(boxs)))
            elif self.edge_type == 'knn':
                # the RGB copy has the same size as the file on disk
                u,v = self.knn_connection(img.size, boxs)
            else:
                raise Exception('Other edge types still under development.')

            g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(boxs), idtype=torch.int32)
            graphs.append(g)

        return graphs, node_labels, edge_labels, features

    def __fromPDF():
        #TODO: dev from PDF import of graphs
        return

    def __fromPAU(self, src: str) -> Tuple[list, list, list, list]:
        """ build graphs from Pau Riba's dataset

        Every '*tif' image in `src` that has both a '<name>_gt.xml' (labelled regions)
        and a '<name>_ocr.xml' (OCR words) next to it becomes one graph: words are the
        nodes, each labelled with the region that contains its center. An edge is
        labelled 'table' when both ends share the label 'positions' or 'total',
        otherwise 'none'.

        Args:
            src (str) : path to where data is stored

        Returns:
            tuple (lists) : graphs, nodes and edge labels, features
        """

        graphs, node_labels, edge_labels = list(), list(), list()
        features = {'paths': [], 'texts': [], 'boxs': []}

        for image in tqdm(os.listdir(src), desc='Creating graphs'):
            if not image.endswith('tif'): continue

            # file name up to its first dot
            img_name = image.split('.')[0]
            file_gt = img_name + '_gt.xml'
            file_ocr = img_name + '_ocr.xml'

            # skip images that miss either annotation file
            if not os.path.isfile(os.path.join(src, file_gt)) or not os.path.isfile(os.path.join(src, file_ocr)): continue
            features['paths'].append(os.path.join(src, image))

            # DOCUMENT REGIONS
            root = ET.parse(os.path.join(src, file_gt)).getroot()
            regions = []
            for parent in root:
                # tags look like '{namespace}Name'; keep the part after '}'
                if parent.tag.split("}")[1] == 'Page':
                    for child in parent:
                        region_label = child[0].attrib['value']
                        # 'points' is a space-separated list of 'x,y' pairs: x is taken
                        # from points 0 and 2, y from points 1 and 3, dropping decimals
                        region_bbox = [int(child[1].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]),
                                    int(child[1].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]),
                                    int(child[1].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]),
                                    int(child[1].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])]
                        regions.append([region_label, region_bbox])

            # DOCUMENT TOKENS
            root = ET.parse(os.path.join(src, file_ocr)).getroot()
            tokens_bbox = []
            tokens_text = []
            nl = []
            center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2)
            # walk Page > TextRegion > TextLine > Word
            for parent in root:
                if parent.tag.split("}")[1] == 'Page':
                    for child in parent:
                        if child.tag.split("}")[1] == 'TextRegion':
                            for elem in child:
                                if elem.tag.split("}")[1] == 'TextLine':
                                    for word in elem:
                                        if word.tag.split("}")[1] == 'Word':
                                            # same point parsing as for the regions above
                                            word_bbox = [int(word[0].attrib['points'].split(" ")[0].split(",")[0].split(".")[0]),
                                                        int(word[0].attrib['points'].split(" ")[1].split(",")[1].split(".")[0]),
                                                        int(word[0].attrib['points'].split(" ")[2].split(",")[0].split(".")[0]),
                                                        int(word[0].attrib['points'].split(" ")[3].split(",")[1].split(".")[0])]
                                            word_text = word[1][0].text
                                            c = center(word_bbox)
                                            # label = first region strictly containing the word center
                                            # NOTE(review): when no region matches, `word_label` keeps the
                                            # value of a previous word (possibly from a previous image), or
                                            # is unbound for the very first word (UnboundLocalError) -
                                            # confirm which default label is intended.
                                            for reg in regions:
                                                r = reg[1]
                                                if r[0] < c[0] < r[2] and r[1] < c[1] < r[3]:
                                                    word_label = reg[0]
                                                    break
                                            tokens_bbox.append(word_bbox)
                                            tokens_text.append(word_text)
                                            nl.append(word_label)

            features['boxs'].append(tokens_bbox)
            features['texts'].append(tokens_text)
            node_labels.append(nl)

            # getting edges
            if self.edge_type == 'fully':
                u, v = self.fully_connected(range(len(tokens_bbox)))
            elif self.edge_type == 'knn':
                u,v = self.knn_connection(Image.open(os.path.join(src, image)).size, tokens_bbox)
            else:
                raise Exception('Other edge types still under development.')

            # edge labels follow the (u, v) order returned above
            el = list()
            for e in zip(u, v):
                if (nl[e[0]] == nl[e[1]]) and (nl[e[0]] == 'positions' or nl[e[0]] == 'total'):
                    el.append('table')
                else: el.append('none')
            edge_labels.append(el)

            g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(tokens_bbox), idtype=torch.int32)
            graphs.append(g)

        return graphs, node_labels, edge_labels, features

    def __fromFUNSD(self, src : str) -> Tuple[list, list, list, list]:
        """Parse FUNSD annotation files into one DGL graph per document.

        Nodes are the annotated entities (``'gt'`` granularity) or the boxes returned
        by ``load_predictions`` (``'yolo'`` granularity). Candidate edges come from
        ``self.fully_connected`` or ``self.knn_connection`` and are labelled ``'pair'``
        when they match an annotated link, ``'none'`` otherwise.

        Args:
            src (str) : path to where data is stored; it is expected to contain the
                ``adjusted_annotations`` and ``images`` folders (plus ``yolo_bbox``
                for the ``'yolo'`` granularity)

        Returns:
            tuple (lists) : graphs, nodes and edge labels, features
                (``features`` is a dict with keys ``'paths'``, ``'texts'``, ``'boxs'``)

        Raises:
            Exception: if the configured edge type or node granularity is not supported
            ValueError: if a link references an entity id missing from the form (``'gt'`` only)
        """

        graphs, node_labels, edge_labels = list(), list(), list()
        features = {'paths': [], 'texts': [], 'boxs': []}

        if self.node_granularity[0] == 'gt':
            for file in tqdm(os.listdir(os.path.join(src, 'adjusted_annotations')), desc='Creating graphs - GT'):

                img_name = f'{file.split(".")[0]}.png'
                img_path = os.path.join(src, 'images', img_name)
                features['paths'].append(img_path)

                with open(os.path.join(src, 'adjusted_annotations', file), 'r') as f:
                    form = json.load(f)['form']

                # getting infos
                boxs, texts, ids, nl = list(), list(), list(), list()
                pair_labels = list()

                for elem in form:
                    boxs.append(elem['box'])
                    texts.append(elem['text'])
                    nl.append(elem['label'])
                    ids.append(elem['id'])
                    pair_labels.extend(elem['linking'])

                # links are expressed with entity ids: translate them to node positions and
                # keep them in a set, so labelling each candidate edge is a constant-time lookup
                linked = {(ids.index(pair[0]), ids.index(pair[1])) for pair in pair_labels}

                node_labels.append(nl)
                features['texts'].append(texts)
                features['boxs'].append(boxs)

                # getting edges
                if self.edge_type[0] == 'fully':
                    u, v = self.fully_connected(range(len(boxs)))
                elif self.edge_type[0] == 'knn':
                    # Pillow opens files lazily: read the size and release the handle
                    with Image.open(img_path) as img:
                        img_size = img.size
                    u, v = self.knn_connection(img_size, boxs)
                else:
                    raise Exception('GraphBuilder exception: Other edge types still under development.')

                edge_labels.append(['pair' if (int(s), int(d)) in linked else 'none' for s, d in zip(u, v)])

                # creating graph
                g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(boxs), idtype=torch.int32)
                graphs.append(g)

        elif self.node_granularity[0] == 'yolo':
            path_preds = os.path.join(src, 'yolo_bbox')
            path_images = os.path.join(src, 'images')
            path_gts = os.path.join(src, 'adjusted_annotations')
            all_paths, all_preds, all_links, all_labels, all_texts = load_predictions(path_preds, path_gts, path_images)
            for f, img_path in enumerate(tqdm(all_paths, desc='Creating graphs - YOLO')):

                features['paths'].append(img_path)
                features['boxs'].append(all_preds[f])
                features['texts'].append(all_texts[f])
                node_labels.append(all_labels[f])
                # NOTE(review): the container type of the links returned by load_predictions
                # is not visible here, so the original list-membership test is kept as is
                pair_labels = all_links[f]

                # getting edges
                if self.edge_type[0] == 'fully':
                    u, v = self.fully_connected(range(len(features['boxs'][f])))
                elif self.edge_type[0] == 'knn':
                    with Image.open(img_path) as img:
                        img_size = img.size
                    u, v = self.knn_connection(img_size, features['boxs'][f])
                else:
                    raise Exception('GraphBuilder exception: Other edge types still under development.')

                edge_labels.append(['pair' if [s, d] in pair_labels else 'none' for s, d in zip(u, v)])

                # creating graph
                g = dgl.graph((torch.tensor(u), torch.tensor(v)), num_nodes=len(features['boxs'][f]), idtype=torch.int32)
                graphs.append(g)
        else:
            #TODO develop OCR too
            raise Exception('GraphBuilder Exception: only \'gt\' or \'yolo\' available for FUNSD.')

        return graphs, node_labels, edge_labels, features

In [None]:
from pathlib import Path
# Folder holding the YAML config files opened by get_config().
CONFIGS = Path("/content/drive/MyDrive/doc2graph-master/configs")
k = 3 #! try changing this value!

# Demo: connect the boxes with knn_connection (k is passed as its third argument)
# and draw the resulting links on the image.
# NOTE(review): `image`, `boxs` and `draw_results` are not defined in this cell —
# they must come from earlier cells; confirm they exist on a fresh kernel.
gb = GraphBuilder()
u, v = gb.knn_connection(image.size, boxs, k)
links = {'src': u, 'dst': v}
draw_results(image, boxs, links)
image

In [None]:
# Root folders of the FUNSD training and testing splits; each one is passed as the
# `src` argument of gb.get_graph (despite the `image_path` name they are dataset folders).
image_path = '/content/drive/MyDrive/doc2graph-master/tutorial/funsd/dataset/training_data'
image_path_test = '/content/drive/MyDrive/doc2graph-master/tutorial/funsd/dataset/testing_data'

In [None]:
# Build graphs, node labels, edge labels and raw feature sources for both FUNSD splits.
graphs, node_labels, edge_labels, features = gb.get_graph(image_path, 'FUNSD')
graphs_test, node_labels_test, edge_labels_test, features_test = gb.get_graph(image_path_test, 'FUNSD')

In [None]:
# Peek at the first training graph and the first test graph.
graphs[0], graphs_test[0]

In [None]:
# Node labels of the first training document.
node_labels[0]

In [None]:
!pip install segmentation-models-pytorch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
    """One decoder stage: upsample, merge the optional skip map, apply two 3x3 conv blocks."""

    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, attention_type=None):
        super().__init__()
        merged_channels = in_channels + skip_channels
        conv_options = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

        self.conv1 = md.Conv2dReLU(merged_channels, out_channels, **conv_options)
        self.attention1 = md.Attention(attention_type, in_channels=merged_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, **conv_options)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        """Decode ``x``.

        With a skip map, ``x`` is resized (nearest neighbour) to the skip's spatial
        size, concatenated with it on the channel axis and passed through the first
        attention module; without one, ``x`` is simply upsampled by a factor of 2.
        """
        if skip is None:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            x = F.interpolate(x, size=skip.shape[2:], mode="nearest")
            x = self.attention1(torch.cat([x, skip], dim=1))
        return self.attention2(self.conv2(self.conv1(x)))


class CenterBlock(nn.Sequential):
    """Two consecutive 3x3 ``Conv2dReLU`` layers: ``in_channels -> out_channels -> out_channels``."""

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        layers = [
            md.Conv2dReLU(
                layer_in,
                out_channels,
                kernel_size=3,
                padding=1,
                use_batchnorm=use_batchnorm,
            )
            for layer_in in (in_channels, out_channels)
        ]
        super().__init__(*layers)


class UnetDecoder(nn.Module):
    """Chain of ``DecoderBlock`` stages that fuses encoder feature maps from deepest to shallowest."""

    def __init__(
            self,
            encoder_channels,
            decoder_channels,
            n_blocks=5,
            use_batchnorm=True,
            attention_type=None,
            center=False,
    ):
        super().__init__()

        if len(decoder_channels) != n_blocks:
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # drop the first entry (same spatial resolution as the input), then order the
        # remaining stages from the deepest encoder output to the shallowest
        encoder_channels = encoder_channels[1:][::-1]
        head_channels = encoder_channels[0]

        self.center = (
            CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
            if center
            else nn.Identity()
        )

        # block i consumes the previous block's output plus one skip map;
        # the last block has no skip map left, hence 0 skip channels
        block_inputs = [head_channels] + list(decoder_channels[:-1])
        block_skips = list(encoder_channels[1:]) + [0]
        self.blocks = nn.ModuleList([
            DecoderBlock(
                in_ch, skip_ch, out_ch,
                use_batchnorm=use_batchnorm, attention_type=attention_type,
            )
            for in_ch, skip_ch, out_ch in zip(block_inputs, block_skips, decoder_channels)
        ])

    def forward(self, *features):
        """Decode encoder feature maps given from shallowest to deepest; the first one is ignored."""
        ordered = features[1:][::-1]  # same reordering as in __init__
        head, skips = ordered[0], ordered[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            x = decoder_block(x, skips[i] if i < len(skips) else None)

        return x

In [None]:
from typing import Optional, Union, List
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import SegmentationModel
from segmentation_models_pytorch.base import SegmentationHead, ClassificationHead


class Unet(SegmentationModel):
    """U-Net: a fully convolutional network for semantic image segmentation.

    The model is an *encoder* and a *decoder* linked by *skip connections*. The
    encoder produces feature maps at several spatial resolutions; the decoder
    concatenates them with its own upsampled maps to predict the mask.
    Paper: https://arxiv.org/abs/1505.04597

    Args:
        encoder_name: Name of the classification backbone used as encoder.
        encoder_depth: Number of encoder stages, in the range [3, 5]. Every stage
            halves the spatial size of the previous one (depth 0 yields
            [(N, C, H, W)], depth 1 yields [(N, C, H, W), (N, C, H // 2, W // 2)],
            and so on). Default is 5.
        encoder_weights: **None** (random initialization), **"imagenet"** or any
            other pretrained weights available for ``encoder_name``.
        decoder_use_batchnorm: **True** inserts a BatchNorm2d layer between Conv2D
            and activation; **"inplace"** uses InplaceABN to reduce memory usage;
            **False** disables normalization.
        decoder_channels: Output channels of each decoder block. Its length must
            equal ``encoder_depth``.
        decoder_attention_type: Attention module used in the decoder, **None** or
            **scse** (https://arxiv.org/abs/1808.08127).
        in_channels: Number of input channels, default is 3 (RGB images).
        classes: Number of channels of the output mask.
        activation: Activation applied after the final convolution:
            **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**,
            **"identity"**, a **callable** or **None** (default).
        aux_params: Parameters of the optional classification head built on top of
            the encoder; no head is created when it is **None** (default).
            Supported keys:
                - classes (int): number of classes
                - pooling (str): "max" or "avg", default "avg"
                - dropout (float): dropout factor in [0, 1)
                - activation (str): "sigmoid"/"softmax", or **None** to return logits

    Returns:
        ``torch.nn.Module``: Unet
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights
        )

        # only VGG backbones get the extra center block between encoder and decoder
        needs_center = encoder_name.startswith("vgg")
        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=needs_center,
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1], out_channels=classes, activation=activation, kernel_size=3
        )

        self.classification_head = (
            ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params)
            if aux_params is not None
            else None
        )

        self.name = "u-{}".format(encoder_name)
        self.initialize()

In [None]:
#!unzip /content/drive/MyDrive/doc2graph-master/tutorial/checkpoints.zip -d /content/drive/MyDrive/doc2graph-master/tutorial/funsd

In [None]:
import random
from typing import Tuple
import spacy
import torch
import torchvision
from tqdm import tqdm
from PIL import Image, ImageDraw
import torchvision.transforms.functional as tvF
# Folder holding the pretrained weights; FeatureBuilder reads 'backbone_unet.pth' from it.
CHECKPOINTS = Path('/content/drive/MyDrive/doc2graph-master/tutorial/funsd/checkpoints')
class FeatureBuilder():

    def __init__(self, d : int = 'cpu'):
        """FeatureBuilder constructor

        Args:
            d (int): device number, if any (cpu or cuda:n)
        """
        self.cfg_preprocessing = get_config('base')
        self.device = d
        self.add_geom = self.cfg_preprocessing.FEATURES.add_geom
        self.add_embs = self.cfg_preprocessing.FEATURES.add_embs
        self.add_hist = self.cfg_preprocessing.FEATURES.add_hist
        self.add_visual = self.cfg_preprocessing.FEATURES.add_visual
        self.add_eweights = self.cfg_preprocessing.FEATURES.add_eweights
        self.add_fudge = self.cfg_preprocessing.FEATURES.add_fudge
        self.num_polar_bins = self.cfg_preprocessing.FEATURES.num_polar_bins

        if self.add_embs:
            self.text_embedder = spacy.load('en_core_web_lg')

        if self.add_visual:
            self.visual_embedder = Unet(encoder_name="mobilenet_v2", encoder_weights=None, in_channels=1, classes=4)
            self.visual_embedder.load_state_dict(torch.load(CHECKPOINTS / 'backbone_unet.pth')['weights'])
            self.visual_embedder = self.visual_embedder.encoder
            self.visual_embedder.to(d)

        self.sg = lambda rect, s : [rect[0]/s[0], rect[1]/s[1], rect[2]/s[0], rect[3]/s[1]] # scaling by img width and height

    def add_features(self, graphs : list, features : list) -> Tuple[list, int]:
        """ Add features to provided graphs

        Args:
            graphs (list) : list of DGLGraphs
            features (list) : list of features "sources", like text, positions and images

        Returns:
            chunks list and its lenght
        """

        for id, g in enumerate(tqdm(graphs, desc='adding features')):

            # positional features
            size = Image.open(features['paths'][id]).size
            feats = [[] for _ in range(len(features['boxs'][id]))]
            geom = [self.sg(box, size) for box in features['boxs'][id]]
            chunks = []

            # 'geometrical' features
            if self.add_geom:

                # TODO add 2d encoding like "LayoutLM*"
                [feats[idx].extend(self.sg(box, size)) for idx, box in enumerate(features['boxs'][id])]
                chunks.append(4)

            # HISTOGRAM OF TEXT
            if self.add_hist:

                [feats[idx].extend(hist) for idx, hist in enumerate(get_histogram(features['texts'][id]))]
                chunks.append(4)

            # textual features
            if self.add_embs:

                # LANGUAGE MODEL (SPACY)
                [feats[idx].extend(self.text_embedder(features['texts'][id][idx]).vector) for idx, _ in enumerate(feats)]
                chunks.append(len(self.text_embedder(features['texts'][id][0]).vector))

            # visual features
            # https://pytorch.org/vision/stable/generated/torchvision.ops.roi_align.html?highlight=roi
            if self.add_visual:
                img = Image.open(features['paths'][id])
                visual_emb = self.visual_embedder(tvF.to_tensor(img).unsqueeze_(0).to(self.device)) # output [batch, channels, dim1, dim2]
                bboxs = [torch.Tensor(b) for b in features['boxs'][id]]
                bboxs = [torch.stack(bboxs, dim=0).to(self.device)]
                h = [torchvision.ops.roi_align(input=ve, boxes=bboxs, spatial_scale=1/ min(size[1] / ve.shape[2] , size[0] / ve.shape[3]), output_size=1) for ve in visual_emb[1:]]
                h = torch.cat(h, dim=1)

                # VISUAL FEATURES (RESNET-IMAGENET)
                [feats[idx].extend(torch.flatten(h[idx]).tolist()) for idx, _ in enumerate(feats)]
                chunks.append(len(torch.flatten(h[0]).tolist()))

            if self.add_eweights:
                u, v = g.edges()
                srcs, dsts =  u.tolist(), v.tolist()
                distances = []
                angles = []

                # TODO CHOOSE WHICH DISTANCE NORMALIZATION TO APPLY
                #! with fully connected simply normalized with max distance between distances
                # m = sqrt((size[0]*size[0] + size[1]*size[1]))
                # parable = lambda x : (-x+1)**4

                for pair in zip(srcs, dsts):
                    dist, angle = polar(features['boxs'][id][pair[0]], features['boxs'][id][pair[1]])
                    distances.append(dist)
                    angles.append(angle)

                m = max(distances)
                polar_coordinates = to_bin(distances, angles, self.num_polar_bins)
                g.edata['feat'] = polar_coordinates

            else:
                distances = ([0.0 for _ in range(g.number_of_edges())])
                m = 1

            g.ndata['geom'] = torch.tensor(geom, dtype=torch.float32)
            g.ndata['feat'] = torch.tensor(feats, dtype=torch.float32)

            distances = torch.tensor([(1-d/m) for d in distances], dtype=torch.float32)
            tresh_dist = torch.where(distances > 0.9, torch.full_like(distances, 0.1), torch.zeros_like(distances))
            g.edata['weights'] = tresh_dist

            norm = []
            num_nodes = len(features['boxs'][id]) - 1
            for n in range(num_nodes + 1):
                neigs = torch.count_nonzero(tresh_dist[n*num_nodes:(n+1)*num_nodes]).tolist()
                try: norm.append([1. / neigs])
                except: norm.append([1.])
            g.ndata['norm'] = torch.tensor(norm, dtype=torch.float32)

            #! DEBUG PURPOSES TO VISUALIZE RANDOM GRAPH IMAGE FROM DATASET
            if False:
                if id == rand_id and self.add_eweights:
                    print("\n\n### EXAMPLE ###")

                    img_path = features['paths'][id]
                    img = Image.open(img_path).convert('RGB')
                    draw = ImageDraw.Draw(img)

                    center = lambda rect: ((rect[2]+rect[0])/2, (rect[3]+rect[1])/2)
                    select = [random.randint(0, len(srcs)) for _ in range(10)]
                    for p, pair in enumerate(zip(srcs, dsts)):
                        if p in select:
                            sc = center(features['boxs'][id][pair[0]])
                            ec = center(features['boxs'][id][pair[1]])
                            draw.line((sc, ec), fill='grey', width=3)
                            middle_point = ((sc[0] + ec[0])/2,(sc[1] + ec[1])/2)
                            draw.text(middle_point, str(angles[p]), fill='black')
                            draw.rectangle(features['boxs'][id][pair[0]], fill='red')
                            draw.rectangle(features['boxs'][id][pair[1]], fill='blue')

                    img.save(f'esempi/FUNSD/edges.png')

        return chunks, len(chunks)

    def get_info(self):
        print(f"-> textual feats: {self.add_embs}\n-> visual feats: {self.add_visual}\n-> edge feats: {self.add_eweights}")

In [None]:
# Download the spaCy model that FeatureBuilder loads with spacy.load('en_core_web_lg').
!python -m spacy download en_core_web_lg

In [None]:
# Use the first GPU when torch can see one, otherwise fall back to the CPU so the
# notebook also runs on CPU-only runtimes.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
fb = FeatureBuilder(d=device)
chunks, _ = fb.add_features(graphs, features) # chunks is used by the model to merge different embeddings together!

In [None]:
# Same feature extraction for the test split (the graphs are modified in place).
chunks_test, _ = fb.add_features(graphs_test, features_test)

In [None]:
# Display the whole list of training graphs after add_features has run on them.
graphs

In [None]:
# Sorted vocabulary of the node classes found in the training documents.
flat_node_labels = [label for doc_labels in node_labels for label in doc_labels]
node_unique_labels = np.unique(np.array(flat_node_labels))
node_num_classes = len(node_unique_labels)
node_num_features = graphs[0].ndata['feat'].shape[1]

# Store every node's class as its position inside `node_unique_labels`.
for idx, labels in enumerate(node_labels):
    class_ids = [np.where(node_unique_labels == target)[0][0] for target in labels]
    graphs[idx].ndata['label'] = torch.tensor(class_ids, dtype=torch.int64)

In [None]:
node_unique_labels_test = np.unique(np.array([l for nl in node_labels_test for l in nl]))
node_num_classes_test = len(node_unique_labels_test)
node_num_features_test = graphs_test[0].ndata['feat'].shape[1]

# Test labels must be encoded with the TRAINING vocabulary (`node_unique_labels`):
# an independent test vocabulary would assign different ids to the same class as
# soon as the two splits do not contain exactly the same set of labels.
unseen_node_labels = set(node_unique_labels_test) - set(node_unique_labels)
assert not unseen_node_labels, f'node labels missing from the training split: {unseen_node_labels}'

for idx, labels in enumerate(node_labels_test):
  graphs_test[idx].ndata['label'] = torch.tensor([np.where(target == node_unique_labels)[0][0] for target in labels], dtype=torch.int64)

In [None]:
# Build the edge vocabulary from ALL training graphs: using only the first graph
# breaks the encoding below whenever that graph lacks one of the classes.
edge_unique_labels = np.unique([l for el in edge_labels for l in set(el)])
edge_num_classes = len(edge_unique_labels)
try:
  edge_num_features = graphs[0].edata['feat'].shape[1]
except KeyError:
  # no 'feat' entry: edge features were not computed for these graphs
  edge_num_features = 0

for idx, labels in enumerate(edge_labels):
  graphs[idx].edata['label'] = torch.tensor([np.where(target == edge_unique_labels)[0][0] for target in labels], dtype=torch.int64)

In [None]:
# Vocabulary of the whole test split (not only of its first graph).
edge_unique_labels_test = np.unique([l for el in edge_labels_test for l in set(el)])
edge_num_classes_test = len(edge_unique_labels_test)
try:
  edge_num_features_test = graphs_test[0].edata['feat'].shape[1]
except KeyError:
  # no 'feat' entry: edge features were not computed for these graphs
  edge_num_features_test = 0

# Encode with the TRAINING vocabulary (`edge_unique_labels`) so that class ids mean
# the same thing in both splits.
unseen_edge_labels = set(edge_unique_labels_test) - set(edge_unique_labels)
assert not unseen_edge_labels, f'edge labels missing from the training vocabulary: {unseen_edge_labels}'

for idx, labels in enumerate(edge_labels_test):
  graphs_test[idx].edata['label'] = torch.tensor([np.where(target == edge_unique_labels)[0][0] for target in labels], dtype=torch.int64)


In [None]:
# First training graph, now including the encoded 'label' node and edge data.
graphs[0]

In [None]:
!pip install torch-geometric

In [None]:
import torch
from torch_geometric.data import Data

def dgl_to_pyg(graph):
    """Convert a DGL graph into a PyTorch Geometric ``Data`` object.

    Args:
        graph: DGL graph carrying ``ndata['feat']`` and ``ndata['label']``;
            ``edata['feat']`` is optional

    Returns:
        Data: ``x`` (node features), ``edge_index`` (``[2, num_edges]``, int64),
        ``edge_attr`` (edge features, or ``None`` when the graph has none) and
        ``y`` (node labels)
    """
    src, dst = graph.edges()
    # the DGL graphs are created with int32 ids, whereas PyG expects edge_index
    # to be torch.long (int64)
    edge_index = torch.stack([src, dst], dim=0).long()
    # edge features only exist when they were computed for this graph
    edge_attr = graph.edata['feat'] if 'feat' in graph.edata else None

    return Data(x=graph.ndata['feat'], edge_index=edge_index, edge_attr=edge_attr, y=graph.ndata['label'])
# PyG version of every training graph, in the same order as `graphs`.
pyg_graphs = [dgl_to_pyg(dgl_graph) for dgl_graph in graphs]

In [None]:
# PyG version of every test graph, in the same order as `graphs_test`.
pyg_graphs_test = [dgl_to_pyg(dgl_graph) for dgl_graph in graphs_test]

In [None]:
# Display the converted training graphs.
pyg_graphs

# ConGAT model

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    """Four-layer graph attention network for node classification.

    The three hidden layers concatenate their attention heads
    (``hidden_channels * num_heads`` features), each followed by ELU and dropout.
    The output layer averages its heads and returns raw logits, so the model
    emits exactly ``out_channels`` scores per node.

    Args:
        in_channels: size of the input node features
        hidden_channels: per-head size of the hidden layers
        out_channels: number of output classes
        num_heads: attention heads used by every layer
        dropout: dropout probability applied after each hidden layer
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads, dropout):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads)
        self.dropout1 = torch.nn.Dropout(p=dropout)
        self.conv2 = GATConv(hidden_channels * num_heads, hidden_channels, heads=num_heads)
        self.dropout2 = torch.nn.Dropout(p=dropout)
        self.conv3 = GATConv(hidden_channels * num_heads, hidden_channels, heads=num_heads)
        self.dropout3 = torch.nn.Dropout(p=dropout)
        # concat=False averages the heads: with the default concatenation the layer
        # would output out_channels * num_heads values instead of one per class
        self.conv4 = GATConv(hidden_channels * num_heads, out_channels, heads=num_heads, concat=False)

    def forward(self, data):
        """Return per-node class logits of shape ``[num_nodes, out_channels]``.

        Args:
            data: object exposing ``x`` (node features) and ``edge_index``
        """
        x, edge_index = data.x, data.edge_index

        x = F.elu(self.conv1(x, edge_index))
        x = self.dropout1(x)
        x = F.elu(self.conv2(x, edge_index))
        x = self.dropout2(x)
        x = F.elu(self.conv3(x, edge_index))
        x = self.dropout3(x)

        # no activation here: F.cross_entropy expects unnormalised logits
        return self.conv4(x, edge_index)

loader = DataLoader(pyg_graphs, batch_size=1, shuffle=True)
# Size the network from the data (computed in the label-encoding cell) instead of
# hard-coding 1752 input features and 4 classes, which only match one feature config.
model = GAT(in_channels=node_num_features, hidden_channels=8, out_channels=node_num_classes, num_heads=4, dropout=0.2)
# NOTE(review): weight_decay=0.9 is unusually strong for AdamW — confirm it is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.9)

In [None]:
model.train()
for epoch in range(2000):
    epoch_loss, n_batches = 0.0, 0
    for data in loader:
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, data.y)  # loss function chosen to match the number of classes
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    # Report the mean loss over the epoch: the last batch alone (one shuffled
    # document) is too noisy to follow, and is undefined when the loader is empty.
    print(f'Epoch {epoch}, Loss: {epoch_loss / max(n_batches, 1)}')

In [None]:
# Evaluation does not need shuffling; a fixed order keeps predictions aligned with
# `pyg_graphs_test` and makes runs reproducible (the aggregated metrics are unaffected).
loader_test = DataLoader(pyg_graphs_test, batch_size=1, shuffle=False)

In [None]:
import torch
from sklearn.metrics import accuracy_score, f1_score

def evaluate_model(model, data_loader):
    """Score a node classifier on every batch of ``data_loader``.

    The model is switched to eval mode and run without gradient tracking; the
    predicted class of a node is the index of its highest output score.

    Args:
        model: callable module returning per-node scores for a batch
        data_loader: iterable of batches exposing the target classes as ``y``

    Returns:
        tuple: (accuracy, weighted F1 score) over all nodes of all batches
    """
    model.eval()
    predicted_classes, true_classes = [], []

    with torch.no_grad():
        for batch in data_loader:
            scores = model(batch)
            predicted_classes += torch.max(scores, dim=1).indices.tolist()
            true_classes += batch.y.tolist()

    return (
        accuracy_score(true_classes, predicted_classes),
        f1_score(true_classes, predicted_classes, average='weighted'),
    )

# Evaluate the trained model on the test graphs and report accuracy and weighted F1.
accuracy, f1 = evaluate_model(model, loader_test)
print(f'Accuracy: {accuracy}, F1 Score: {f1}')