In [1]:
import torch
from torch import nn
from torch.utils import data
from tqdm import tqdm

from pathlib import Path
from utils import xml_parser, Neighbour, visualizer, candidate
from utils import operations as op
from utils import preprocess
from network import dataset

In [2]:
PAD = 0  # vocabulary index reserved for the '<PAD>' token

def get_neighbours(list_of_neighbours, vocabulary, n_neighbours):
    """Flatten a candidate's neighbours into a fixed-length numeric list.

    Each neighbour contributes three consecutive values: the vocabulary index
    of its 'text', followed by its 'x' and 'y' values. Texts not yet present
    in ``vocabulary`` are added in place with the next free index (this also
    happens for neighbours that are later truncated away).

    Args:
        list_of_neighbours: iterable of dicts with 'text', 'x' and 'y' keys.
        vocabulary: dict mapping text -> integer index; mutated in place.
        n_neighbours: number of neighbours the output must describe.

    Returns:
        A list of ``n_neighbours * 3`` values. Surplus neighbours are dropped;
        missing ones are filled with ``[pad_index, 0., 0.]`` where
        ``pad_index`` is the vocabulary index of '<PAD>' (falling back to
        ``PAD``).
    """
    neighbours = list()

    for neighbour in list_of_neighbours:
        if neighbour['text'] not in vocabulary:
            vocabulary[neighbour['text']] = len(vocabulary)

        neighbours.extend(
            [
                vocabulary[neighbour['text']],
                neighbour['x'],
                neighbour['y']
            ]
        )

    len_neighbours = len(neighbours) // 3
    if len_neighbours > n_neighbours:
        neighbours = neighbours[:(n_neighbours * 3)]
    elif len_neighbours < n_neighbours:
        # Pad with the integer index of '<PAD>', not the literal string:
        # every other text slot holds a vocabulary index, and a str in the
        # row makes the feature list impossible to turn into a tensor.
        pad_index = vocabulary.get('<PAD>', PAD)
        neighbours.extend([pad_index, 0., 0.] * (n_neighbours - len_neighbours))

    return neighbours

def parse_input(annotations, fields_dict, n_neighbours=5, vocabulary=None):
    """Build flat feature rows and binary labels from annotated documents.

    For every field of every annotation that has at least one true candidate,
    the first true candidate yields a row labelled 1.0 and every entry of
    'other_candidates' yields a row labelled 0.0. Fields without a true
    candidate contribute no rows at all.

    A row is ``[fields_dict[field], x, y]`` followed by the list returned by
    ``get_neighbours`` for that candidate.

    Args:
        annotations: iterable of dicts, each with a 'fields' mapping.
        fields_dict: mapping from field name to its numeric id.
        n_neighbours: number of neighbours encoded per candidate.
        vocabulary: optional text -> index dict; when falsy a fresh
            ``{'<PAD>': PAD}`` is created, otherwise it is extended in place.

    Returns:
        ``(x, Y, vocabulary)``: the rows, their labels and the vocabulary.
    """
    if not vocabulary:
        vocabulary = {'<PAD>': PAD}

    rows, labels = [], []

    def _emit(field_name, cand, label):
        # Label first, then neighbours (may grow the vocabulary), then the row.
        labels.append(label)
        encoded = get_neighbours(cand['neighbours'], vocabulary, n_neighbours)
        rows.append([fields_dict[field_name], cand['x'], cand['y']] + encoded)

    for document in tqdm(annotations, desc='Parsing Input'):
        doc_fields = document['fields']

        for field_name in doc_fields:
            positives = doc_fields[field_name]['true_candidates']
            if not positives:
                continue

            # Only the first true candidate is used as the positive sample.
            _emit(field_name, positives[0], 1.)

            for negative in doc_fields[field_name]['other_candidates']:
                _emit(field_name, negative, 0.)

    return rows, labels, vocabulary

In [4]:
# Scratch vocabulary seeded with the padding token at index 0.
vocab = {'<PAD>':0}

In [5]:
# Smoke-test get_neighbours on one true candidate; the call also grows `vocab` in place.
# NOTE(review): `annotation` is not assigned in any visible cell (hence the NameError
# recorded below) -- it has to be built before this cell can run on a fresh kernel.
_ = get_neighbours(annotation[1]['fields']['invoice_date']['true_candidates'][0]['neighbours'], vocab, 5)

NameError: name 'annotation' is not defined

In [164]:
# Number of entries currently in the vocabulary (including '<PAD>').
len(vocab)

17

In [9]:
# Maps each field name to the numeric id that parse_input puts first in a sample row.
field_dict = {'invoice_date':0, 'invoice_no':1, 'total':2}

In [166]:
# Build samples/labels with the default n_neighbours; rebinds `vocab` to the returned vocabulary.
# NOTE(review): depends on `annotation` already existing in the kernel; no visible cell defines it.
x, Y, vocab = parse_input(annotation, field_dict)

Parsing Input: 100%|██████████| 505/505 [00:00<00:00, 705.20it/s]


In [17]:
class DocumentsDataset(data.Dataset):
    """Map-style dataset of candidate samples built from annotated documents."""

    def __init__(self, xmls_path, ocr_path, image_path, candidate_path,
                 field_dict, n_neighbour=5, vocab=None):
        """Run the preprocessing pipeline and keep its outputs on the instance.

        The annotations read from ``xmls_path`` are passed, in order, through
        candidate attachment (``candidate_path``), neighbour attachment
        (``ocr_path``) and position normalisation, then handed to
        ``preprocess.parse_input`` together with ``field_dict``,
        ``n_neighbour`` and ``vocab``. The result is stored as
        ``self.features``, ``self.labels`` and ``self.vocab``.

        ``image_path`` is accepted but not used here.
        """
        # get_data returns three values; only the annotations are needed.
        documents, _classes_count, _class_mapping = xml_parser.get_data(xmls_path)
        documents = candidate.attach_candidate(documents, candidate_path)
        documents = Neighbour.attach_neighbour(documents, ocr_path)
        documents = op.normalize_positions(documents)

        parsed = preprocess.parse_input(documents, field_dict, n_neighbour, vocab)
        self.features, self.labels, self.vocab = parsed

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(feature_tensor, label)`` for the sample at ``idx``."""
        sample = torch.tensor(self.features[idx])
        target = self.labels[idx]
        return sample, target

In [18]:
# All inputs live under ./dataset relative to the current working directory.
this_dir = Path.cwd()
xmls_path = this_dir.joinpath("dataset", "xmls")
ocr_path = this_dir.joinpath("dataset", "tesseract_results_lstm")
image_path = this_dir.joinpath("dataset", "images")
candidate_path = this_dir.joinpath("dataset", "candidates")

In [32]:
# Build the dataset with the default n_neighbour and a fresh vocabulary.
# Note: this uses DocumentsDataset from the imported `network.dataset` module,
# not the class of the same name defined in this notebook.
datas = dataset.DocumentsDataset(xmls_path, ocr_path, image_path, candidate_path, field_dict)

Reading Annotations: 505it [00:00, 803.05it/s]
Attaching Candidate: 100%|██████████| 505/505 [00:01<00:00, 439.85it/s]
Attaching Neighbours: 100%|██████████| 505/505 [01:17<00:00,  4.56it/s]
normalizing position coordinates: 100%|██████████| 505/505 [00:01<00:00, 275.64it/s]
Parsing Input: 100%|██████████| 505/505 [00:00<00:00, 819.88it/s]


In [33]:
# Inspect the first sample: a (feature tensor, label) pair.
datas[0]

(tensor([ 1.0000,  0.6985,  0.0865,  1.0000,  0.0176, -0.0512,  2.0000,  0.0472,
         -0.0512,  3.0000,  0.0936, -0.0511,  4.0000,  0.1104, -0.0143,  5.0000,
         -0.0133, -0.0145]), 1.0)

In [39]:
# Wrap the dataset in a loader yielding batches of 32; no shuffle argument is passed.
datal = data.DataLoader(datas, batch_size=32)

In [40]:
# Pull the first batch. Note: `x` and `Y` are also assigned by the parse_input cell,
# so these names are rebound here to batch tensors.
x, Y = next(iter(datal))

In [42]:
# Feature tensor of the second sample in the batch.
x[1]

tensor([ 1.0000e+00,  2.2400e-01,  9.3091e-02,  2.2000e+01, -9.0941e-02,
        -5.6909e-02,  2.3000e+01, -3.0824e-02, -5.7909e-02,  2.4000e+01,
         1.1176e-02, -1.4909e-02,  2.5000e+01,  2.5647e-02,  9.0909e-05,
         2.6000e+01,  6.1176e-02,  0.0000e+00])