In [1]:
import torch
from torch import nn
from torch.utils import data
from tqdm import tqdm

from pathlib import Path
from utils import xml_parser, Neighbour, visualizer, candidate
from utils import operations as op
from utils import preprocess
from network import dataset

In [150]:
PAD = 0

def get_neighbours(list_of_neighbours, vocabulary, n_neighbours):
    """Encode a candidate's neighbours as fixed-length id and coordinate lists.

    Each neighbour's ``text`` is looked up in ``vocabulary``. Unseen words are
    added in place with the next free index. The result is then truncated or
    padded to exactly ``n_neighbours`` entries.

    Args:
        list_of_neighbours: iterable of dicts with ``'text'``, ``'x'`` and ``'y'`` keys.
        vocabulary: word -> index mapping, mutated in place. It must contain
            ``'<PAD>'`` whenever padding is needed.
        n_neighbours: target number of neighbours.

    Returns:
        (neighbours, neighbour_cords): a list of ``n_neighbours`` vocabulary ids
        and a list of ``n_neighbours`` ``[x, y]`` pairs. Padded slots use the
        ``'<PAD>'`` id and ``[0., 0.]``.
    """
    neighbours = []
    neighbour_cords = []

    for neighbour in list_of_neighbours:
        text = neighbour['text']
        if text not in vocabulary:
            vocabulary[text] = len(vocabulary)
        neighbours.append(vocabulary[text])
        neighbour_cords.append([neighbour['x'], neighbour['y']])

    # Keep only the first n_neighbours entries when there are too many.
    neighbours = neighbours[:n_neighbours]
    neighbour_cords = neighbour_cords[:n_neighbours]

    # Pad ids and coordinates by the same amount, so both lists always have
    # exactly n_neighbours entries and can be stacked into a batch.
    n_missing = n_neighbours - len(neighbours)
    if n_missing > 0:
        neighbours.extend([vocabulary['<PAD>']] * n_missing)
        # Fresh inner lists so the padded coordinate pairs are not aliased.
        neighbour_cords.extend([[0., 0.] for _ in range(n_missing)])

    return neighbours, neighbour_cords

def parse_input(annotations, fields_dict, n_neighbours=5, vocabulary=None):
    """Build flat training samples from annotated documents.

    For every field of every annotation that has at least one true candidate:
    - one positive sample (label ``1.``) is built from the first true candidate;
    - one negative sample (label ``0.``) is built from each entry of the
      field's ``other_candidates``.

    Fields without a true candidate contribute no samples.

    Args:
        annotations: iterable of annotation dicts. Each has a ``'fields'`` dict
            mapping field name -> {'true_candidates': [...], 'other_candidates': [...]}.
            Every candidate has ``'x'``, ``'y'`` and ``'neighbours'`` keys.
        fields_dict: field name -> integer field id.
        n_neighbours: number of neighbours kept (truncated/padded) per candidate.
        vocabulary: optional word -> index mapping, extended in place. When
            ``None`` or empty, it is initialised with ``{'<PAD>': PAD}``. A
            non-empty vocabulary is expected to already contain ``'<PAD>'``.

    Returns:
        (field_ids, candidate_cords, neighbours, neighbour_cords, labels, vocabulary):
        five parallel lists with one entry per sample, plus the (possibly
        extended) vocabulary.
    """
    field_ids = []
    candidate_cords = []
    neighbours = []
    neighbour_cords = []
    labels = []

    if vocabulary is None:
        vocabulary = {}
    if not vocabulary:
        # Fill the caller's dict in place rather than replacing it. An empty
        # dict passed in is therefore populated, and '<PAD>' exists for padding.
        vocabulary['<PAD>'] = PAD

    def _add_sample(field, cand, label):
        """Append one (candidate, label) sample to the parallel output lists."""
        ids, cords = get_neighbours(cand['neighbours'], vocabulary, n_neighbours)
        labels.append(label)
        field_ids.append(fields_dict[field])
        candidate_cords.append([cand['x'], cand['y']])
        neighbours.append(ids)
        neighbour_cords.append(cords)

    for annotation in tqdm(annotations, desc='Parsing Input'):
        fields = annotation['fields']
        for field, field_data in fields.items():
            true_candidates = field_data['true_candidates']
            if not true_candidates:
                continue
            _add_sample(field, true_candidates[0], 1.)
            for other in field_data['other_candidates']:
                _add_sample(field, other, 0.)

    return field_ids, candidate_cords, neighbours, neighbour_cords, labels, vocabulary

In [5]:
# Shared word -> index vocabulary; index 0 is the '<PAD>' entry.
vocab = dict([('<PAD>', 0)])

In [152]:
len(vocab)  # number of entries currently in the shared vocabulary

1

In [4]:
# Field name -> integer field id, numbered in declaration order.
field_dict = {name: idx for idx, name in enumerate(('invoice_date', 'invoice_no', 'total'))}

In [173]:
class DocumentsDataset(data.Dataset):
    """Annotated documents turned into per-candidate training samples.

    Construction runs the full preprocessing pipeline (XML annotations ->
    candidates -> OCR neighbours -> normalised positions -> ``parse_input``)
    and keeps the resulting parallel lists. Indexing returns one sample.
    """

    def __init__(self, xmls_path, ocr_path, image_path, candidate_path,
                 field_dict, n_neighbour=5, vocab=None):
        """Run preprocessing and store the parsed samples.

        ``image_path`` is accepted but not used here.
        """
        docs, _classes_count, _class_mapping = xml_parser.get_data(xmls_path)
        docs = candidate.attach_candidate(docs, candidate_path)
        docs = Neighbour.attach_neighbour(docs, ocr_path)
        docs = op.normalize_positions(docs)

        (self.field_ids,
         self.candidate_cords,
         self.neighbours,
         self.neighbour_cords,
         self.labels,
         self.vocab) = parse_input(docs, field_dict, n_neighbour, vocab)

    def __len__(self):
        """Number of stored samples."""
        return len(self.field_ids)

    def __getitem__(self, idx):
        """Return (field_id, candidate_cords, neighbour_ids, neighbour_cords) as
        tensors, followed by the label as stored (not converted to a tensor)."""
        tensor_columns = (
            self.field_ids,
            self.candidate_cords,
            self.neighbours,
            self.neighbour_cords,
        )
        return tuple(torch.tensor(column[idx]) for column in tensor_columns) + (self.labels[idx],)

In [2]:
# Every input lives in a sub-directory of ./dataset, relative to the kernel's working directory.
this_dir = Path.cwd()
xmls_path, ocr_path, image_path, candidate_path = (
    this_dir / "dataset" / sub_dir
    for sub_dir in ("xmls", "tesseract_results_lstm", "images", "candidates")
)

In [156]:
# Preprocess all annotations: the same steps DocumentsDataset.__init__ runs internally.
annotation, classes_count, class_mapping = xml_parser.get_data(xmls_path)
annotation = op.normalize_positions(
    Neighbour.attach_neighbour(
        candidate.attach_candidate(annotation, candidate_path),
        ocr_path,
    )
)

Reading Annotations: 505it [00:00, 576.57it/s]
Attaching Candidate: 100%|██████████| 505/505 [00:01<00:00, 461.50it/s]
Attaching Neighbours: 100%|██████████| 505/505 [01:17<00:00,  4.75it/s]
normalizing position coordinates: 100%|██████████| 505/505 [00:01<00:00, 272.11it/s]


In [157]:
# Parse only the first annotation as a quick sanity check; unseen words are added to `vocab` in place.
field, c_co, ns, n_cos, ls, v = parse_input(
    annotation[:1], field_dict, n_neighbours=5, vocabulary=vocab
)

Parsing Input: 100%|██████████| 1/1 [00:00<00:00, 167.00it/s]


In [158]:
n_cos[1]  # neighbour coordinates of the second sample parsed above

[[-0.09094117647058825, -0.0569090909090909],
 [-0.030823529411764722, -0.0579090909090909],
 [0.011176470588235288, -0.0149090909090909],
 [0.02564705882352941, 9.090909090909982e-05],
 [0.061176470588235304, 0.0]]

In [6]:
# NOTE: this builds network.dataset.DocumentsDataset (the module imported at the top),
# not the DocumentsDataset class defined in this notebook. It also re-runs the full
# preprocessing (see the progress bars below).
datas = dataset.DocumentsDataset(xmls_path, ocr_path, image_path, candidate_path, field_dict)

Reading Annotations: 505it [00:00, 702.48it/s]
Attaching Candidate: 100%|██████████| 505/505 [00:01<00:00, 459.64it/s]
Attaching Neighbours: 100%|██████████| 505/505 [01:17<00:00,  4.73it/s]
normalizing position coordinates: 100%|██████████| 505/505 [00:02<00:00, 252.01it/s]
Parsing Input: 100%|██████████| 505/505 [00:00<00:00, 522.03it/s]


In [7]:
datas[0]  # first sample: (field id, candidate coords, neighbour ids, neighbour coords, label)

(tensor(1),
 tensor([0.6985, 0.0865]),
 tensor([1, 2, 3, 4, 5]),
 tensor([[ 0.0176, -0.0512],
         [ 0.0472, -0.0512],
         [ 0.0936, -0.0511],
         [ 0.1104, -0.0143],
         [-0.0133, -0.0145]]),
 1.0)

In [161]:
# Batch the dataset two samples at a time for a quick shape check.
datal = data.DataLoader(dataset=datas, batch_size=2)

In [162]:
# Pull a single batch to inspect the collated tensors.
(field_ids, c_cords,
 neighs, neigh_cords,
 labels) = next(iter(datal))

In [67]:
# Scratch check of tuple unpacking: binds a, b, c to the elements and f to the whole tuple.
a, b, c = f = (2, 3, 4)

In [70]:
c  # third element unpacked from f

4

In [None]:
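# Shape notes (B = batch size, N = number of neighbours per candidate, D = embedding size):
#
# Inputs X and labels y:
#   y = [B, 1]
#   X = [B, (3 + N * 3)]   # 3 = field id + candidate (x, y); N * 3 = per-neighbour id + (x, y)
#
# y, words, cords:
#   words [B, N]
#   cords [B, N, 2]
#
# After embedding into D dimensions (TODO confirm):
#   words [B, N]    -> [B, N, D]
#   cords [B, N, 2] -> [B, N, D]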