# Imports

In [26]:
import numpy as np
import pandas as pd
import matplotlib as plt
import cv2
import os
from PIL import Image
from torch.utils.data import Dataset

# Dataset Class Definition

In [21]:
class SROIEDataset(Dataset):
    """Dataset over a directory of receipt images with per-image annotation files.

    Each item is the image decoded as an RGB numpy array, optionally passed
    through ``transform``. Annotation files are looked up by swapping the
    image file extension for ``.txt`` inside ``entities_dir`` / ``box_dir``.

    Args:
        image_dir: Directory containing the image files.
        entities_dir: Directory containing one entities ``.txt`` file per image.
        box_dir: Directory containing one bounding-box ``.txt`` file per image.
        mask_dir: Optional directory for masks (stored, not used yet).
        transform: Optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``'image'`` key
            (looks like the albumentations convention -- TODO confirm).
    """

    # Column layout of one line in a box file: four (x, y) corners, then text.
    _COORD_COLUMNS = ['x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3']
    # Only two opposite corners are kept in the returned frame.
    _DROPPED_COLUMNS = ['x1', 'y1', 'x3', 'y3']

    def __init__(self, image_dir, entities_dir, box_dir, mask_dir=None, transform=None):
        self.image_dir = image_dir
        self.entities_dir = entities_dir
        self.box_dir = box_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so index -> image is reproducible.
        self.images = sorted(os.listdir(image_dir))
        self.mask_dir = mask_dir

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the image at ``index`` as an RGB array and apply ``transform``.

        Returns:
            The (possibly transformed) image. The annotation paths are
            resolved but not loaded or returned yet.
        """
        image_name = self.images[index]
        # splitext handles any image extension, not just a lowercase '.jpg'.
        annotation_name = os.path.splitext(image_name)[0] + '.txt'

        img_path = os.path.join(self.image_dir, image_name)
        # NOTE(review): these two paths are not consumed yet; kept for the
        # upcoming annotation loading. box_path previously pointed into
        # entities_dir by mistake.
        entities_path = os.path.join(self.entities_dir, annotation_name)
        box_path = os.path.join(self.box_dir, annotation_name)

        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform is not None:
            augmentations = self.transform(image=image)
            image = augmentations['image']

        return image

    def grab_points(self, box_path):
        """Parse a box file into a DataFrame of two corners plus the text.

        Each non-blank line must hold eight comma-separated integers followed
        by the text, which may itself contain commas.

        Args:
            box_path: Path to the box ``.txt`` file.

        Returns:
            DataFrame with int32 columns ``x0, y0, x2, y2`` and a string
            column ``line``, one row per non-blank line of the file.

        Raises:
            ValueError: If a non-blank line has fewer than eight coordinate
                fields or a coordinate is not an integer.
        """
        n_coords = len(self._COORD_COLUMNS)
        rows = []
        with open(box_path, 'r', errors='ignore') as f:
            for line_no, line in enumerate(f.read().splitlines(), start=1):
                if not line.strip():
                    # Blank lines carry no box and would break int parsing.
                    continue

                fields = line.split(',')
                if len(fields) < n_coords:
                    raise ValueError(
                        'Malformed box line {} in {}: {!r}'.format(line_no, box_path, line)
                    )

                coords = [int(value) for value in fields[:n_coords]]
                # The text may contain commas, so re-join everything after the coords.
                text = ','.join(fields[n_coords:])
                rows.append([*coords, text])

        bbox_df = pd.DataFrame(rows, columns=self._COORD_COLUMNS + ['line'])
        # Cast only the coordinate columns: passing an integer dtype to the
        # DataFrame constructor would also try to cast the text column.
        bbox_df = bbox_df.astype({column: np.int32 for column in self._COORD_COLUMNS})

        return bbox_df.drop(columns=self._DROPPED_COLUMNS)

    def create_masks(self):
        """Allocate an empty single-channel mask sized like each image.

        TODO: the mask is neither filled in nor stored/returned yet.

        Raises:
            OSError: If an entry of ``image_dir`` cannot be decoded as an image.
        """
        for image_name in self.images:
            img_path = os.path.join(self.image_dir, image_name)
            img_read = cv2.imread(img_path)
            if img_read is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError('Could not read image: {}'.format(img_path))
            mask = np.zeros(img_read.shape[:2], dtype=np.uint8)
        

In [31]:
# Point the dataset at the SROIE 2019 training split.
dataset = SROIEDataset(
    box_dir='../data/SROIE2019/train/box',
    entities_dir='../data/SROIE2019/train/entities',
    image_dir='../data/SROIE2019/train/img/',
)

# Parse one receipt's box file as a smoke test of grab_points.
example_bbox_df = dataset.grab_points(
    os.path.join(dataset.box_dir, 'X51005365179.txt')
)
print(example_bbox_df)

     x0    y0   x2    y2                                  line
0    11    28  366    85                             3-1707067
1   448   163  616   193                            (481500-M)
2   113   203  660   229             C W KHOO HARDWARE SDN BHD
3   224   233  552   261             NO.50 , JALAN PBS 14/11 ,
4   109   266  666   294  KAWASAN PERINDUSTRIAN BUKIT SERDANG,
5   133   298  360   322                     TEL : 03-89410243
6   409   298  639   323                     FAX : 03-89410243
7   206   329  568   358            GST REG NO. : 000549584896
8   279   383  452   408                           TAX INVOICE
9    50   435  189   458                           INVOICE NO.
10  230   434  426   463                        : CR 1803/0064
11   52   468  112   491                                  DATE
12  231   468  509   491                 : 01-03-18 5:23:26 PM
13   52   502  188   524                           CASHIER NO.
14  230   501  308   525                               

  bbox_df = pd.DataFrame(bbox_word_list,
