In [1]:
import torch
from torch import nn
from torch.utils import data
from tqdm import tqdm

from pathlib import Path
from utils import xml_parser, Neighbour, visualizer, candidate
from utils import operations as op
from utils import preprocess
from network import dataset

In [150]:
PAD = 0

def get_neighbours(list_of_neighbours, vocabulary, n_neighbours):
    """Returns a list of neighbours and coordinates."""
    neighbours = list()
    neighbour_cords = list()
    
    for neighbour in list_of_neighbours:
        if neighbour['text'] not in vocabulary:
            vocabulary[neighbour['text']] = len(vocabulary)

        neighbours.append(vocabulary[neighbour['text']])
        neighbour_cords.append(
            [
                neighbour['x'],
                neighbour['y']
            ]
        )
    
    len_neighbours = len(neighbours)
    if len_neighbours != n_neighbours:
        if  len_neighbours > n_neighbours:
            neighbours = neighbours[:n_neighbours]
            neighbour_cords = neighbour_cords[:n_neighbours]
        else:
            neighbours.append(vocabulary['<PAD>'])
            neighbour_cords.extend([[0., 0.]] * (n_neighbours - len_neighbours))

    return neighbours, neighbour_cords

def parse_input(annotations, fields_dict, n_neighbours=5, vocabulary=None):
    """Generates input samples from annotations data."""
    
    field_ids = list()
    candidate_cords = list()
    neighbours = list()
    neighbour_cords = list()
    labels = list()
    if not vocabulary:
        vocabulary = { '<PAD>':PAD }

    for annotation in tqdm(annotations, desc='Parsing Input'):
        
        fields = annotation['fields']
        
        for field in fields:
            if fields[field]['true_candidates']:
                _neighbours, _neighbour_cords = get_neighbours(
                    fields[field]['true_candidates'][0]['neighbours'],
                    vocabulary, n_neighbours
                )
                labels.append(1.)
                field_ids.append(fields_dict[field])
                candidate_cords.append(
                    [
                        fields[field]['true_candidates'][0]['x'],
                        fields[field]['true_candidates'][0]['y']
                    ]
                )
                neighbours.append(_neighbours)
                neighbour_cords.append(_neighbour_cords)
               
                for candidate in fields[field]['other_candidates']:

                    _neighbours, _neighbour_cords = get_neighbours(candidate['neighbours'], vocabulary, n_neighbours)
                    labels.append(0.)
                    field_ids.append(fields_dict[field])
                    candidate_cords.append(
                        [
                            candidate['x'],
                            candidate['y']
                        ]
                    )
                    neighbours.append(_neighbours)
                    neighbour_cords.append(_neighbour_cords)
                    
                    
    return field_ids, candidate_cords, neighbours, neighbour_cords, labels, vocabulary

In [5]:
vocab = {'<PAD>':0}

In [152]:
len(vocab)

1

In [4]:
field_dict = {'invoice_date':0, 'invoice_no':1, 'total':2}

In [173]:
class DocumentsDataset(data.Dataset):
    """Stores the annotated documents dataset."""
    
    def __init__(self, xmls_path, ocr_path, image_path,candidate_path,
                 field_dict, n_neighbour=5, vocab=None):
        """ Initialize the dataset with preprocessing """
        annotation, classes_count, class_mapping = xml_parser.get_data(xmls_path)
        annotation = candidate.attach_candidate(annotation, candidate_path)
        annotation = Neighbour.attach_neighbour(annotation, ocr_path)
        annotation = op.normalize_positions(annotation)
        _data = parse_input(annotation, field_dict, n_neighbour, vocab)
        self.field_ids, self.candidate_cords, self.neighbours, self.neighbour_cords, self.labels, self.vocab = _data
    
    def __len__(self):
        return len(self.field_ids)
    
    def __getitem__(self, idx):
        
        return (
            torch.tensor(self.field_ids[idx]),
            torch.tensor(self.candidate_cords[idx]),
            torch.tensor(self.neighbours[idx]),
            torch.tensor(self.neighbour_cords[idx]),
            self.labels[idx]
        )

In [2]:
this_dir = Path.cwd()
xmls_path = this_dir / "dataset" / "xmls"
ocr_path = this_dir / "dataset" / "tesseract_results_lstm"
image_path = this_dir / "dataset" / "images"
candidate_path = this_dir / "dataset" / "candidates"

In [3]:
annotation, classes_count, class_mapping = xml_parser.get_data(xmls_path)
annotation = candidate.attach_candidate(annotation, candidate_path)
annotation, vocab = Neighbour.attach_neighbour(annotation, ocr_path)
annotation = op.normalize_positions(annotation)

Reading Annotations: 505it [00:00, 603.19it/s]
Attaching Candidate: 100%|██████████| 505/505 [00:01<00:00, 422.56it/s]
Attaching Neighbours: 100%|██████████| 505/505 [01:18<00:00,  4.74it/s]
normalizing position coordinates:   5%|▌         | 26/505 [00:00<00:01, 259.73it/s]

Vocabulary of size 515 built!


normalizing position coordinates: 100%|██████████| 505/505 [00:02<00:00, 246.28it/s]


In [8]:
field, c_co, ns, n_cos, ls, v=preprocess.parse_input(annotation[:100], class_mapping, 5, vocab)

Parsing Input: 100%|██████████| 100/100 [00:00<00:00, 368.67it/s]


In [9]:
field

[array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 1., 0.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([0., 0., 1.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1., 0., 0.]),
 array([1.

In [158]:
n_cos[1]

[[-0.09094117647058825, -0.0569090909090909],
 [-0.030823529411764722, -0.0579090909090909],
 [0.011176470588235288, -0.0149090909090909],
 [0.02564705882352941, 9.090909090909982e-05],
 [0.061176470588235304, 0.0]]

In [5]:
datas = dataset.DocumentsDataset(xmls_path, ocr_path, image_path, candidate_path)

Reading Annotations: 505it [00:00, 678.08it/s]
Attaching Candidate:   9%|▉         | 45/505 [00:00<00:01, 443.71it/s]

Class Mapping: {'invoice_no': 0, 'invoice_date': 1, 'total': 2}
Classs counts: {'invoice_no': 453, 'invoice_date': 459, 'total': 191}


Attaching Candidate: 100%|██████████| 505/505 [00:01<00:00, 445.60it/s]
Attaching Neighbours: 100%|██████████| 505/505 [01:18<00:00,  4.70it/s]
normalizing position coordinates:   9%|▉         | 46/505 [00:00<00:02, 227.05it/s]

Vocabulary of size 515 built!


normalizing position coordinates: 100%|██████████| 505/505 [00:02<00:00, 227.51it/s]
Parsing Input: 100%|██████████| 505/505 [00:01<00:00, 324.20it/s]


In [18]:
datas.vocab

{'<PAD>': 0,
 '<NUMBER>': 1,
 '<RARE>': 2,
 'date': 4,
 'rate': 5,
 'nm': 6,
 'end': 7,
 'start': 8,
 ':30': 9,
 'week:': 10,
 'weekdays': 11,
 ' ': 12,
 'to': 13,
 'of': 14,
 'spots/week': 15,
 'the': 16,
 'time': 17,
 '-': 18,
 'news': 19,
 'and': 20,
 '=': 21,
 'n': 22,
 'or': 23,
 'by': 24,
 'agency': 25,
 'm-f': 26,
 '/': 27,
 'rating': 28,
 'advertiser': 29,
 'contract': 30,
 '|': 31,
 'not': 32,
 '#': 33,
 'station': 34,
 'payment': 35,
 'is': 36,
 '  ': 37,
 'on': 38,
 'for': 39,
 'in': 40,
 'notice': 41,
 'type': 42,
 'advertising': 43,
 'amount': 44,
 'a': 45,
 'order': 46,
 'pm': 47,
 'spots': 48,
 'with': 49,
 '09/17/12': 50,
 'pre-emptible': 51,
 'am': 52,
 '09/10/12': 53,
 'ch': 54,
 'any': 55,
 'class': 56,
 'start/end': 57,
 'length': 58,
 'description': 59,
 'this': 60,
 '09/24/12': 61,
 'wavy': 62,
 'cm': 63,
 'air': 64,
 '~~': 65,
 'week': 66,
 'product': 67,
 '10/22/12': 68,
 'at': 69,
 'shall': 70,
 '09/16/12': 71,
 'revision': 72,
 'page': 73,
 'days': 74,
 's': 7

In [17]:
datas[5]

(tensor([1., 0., 0.]),
 tensor([0.6981, 0.0869]),
 tensor([113,   4, 306, 129,  30]),
 tensor([[ 0.0180, -0.0515],
         [ 0.0475, -0.0515],
         [ 0.0940, -0.0515],
         [ 0.1107, -0.0146],
         [-0.0129, -0.0148]]),
 tensor([0.]))

In [161]:
datal = data.DataLoader(datas, batch_size=2)

In [162]:
field_ids, c_cords, neighs, neigh_cords, labels = next(iter(datal))

In [67]:
f =(2,3,4)
a, b, c = f

In [70]:
c

4

In [None]:
X y 
y = [B, 1]
X = [B, (3 + N * 3)]


y words cords

words [B, N]
cords [B, N, 2]

words [B, N] = [B, N, D]
cords [B, N, 2 ] = [B, N, D]