# Imports

In [38]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
from PIL import Image
from torch.utils.data import Dataset

from tqdm.notebook import tqdm

# Dataset Class Definition

In [43]:
class SROIEDataset(Dataset):
    """Dataset pairing receipt images with binary text-region masks.

    Each item is an ``(image, mask)`` tuple: ``image`` is the RGB image as a
    numpy array and ``mask`` is a float32 array holding 1.0 inside the
    annotated boxes and 0.0 elsewhere. Masks are read from ``mask_dir``;
    ``create_masks()`` generates them from the box annotation files.

    Args:
        image_dir: Directory containing the input images.
        entities_dir: Directory containing the per-image entity ``.txt`` files
            (stored for callers; not read by this class).
        box_dir: Directory containing the per-image box ``.txt`` files.
        mask_dir: Directory the masks are read from / written to. When
            ``None``, ``create_masks()`` falls back to ``DEFAULT_MASK_DIR``.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` and expected to return a
            mapping with ``'image'`` and ``'mask'`` keys.
    """

    # Fallback output directory used by create_masks() when mask_dir is None.
    DEFAULT_MASK_DIR = '../data/SROIE2019/train/mask/'

    def __init__(self, image_dir, entities_dir, box_dir, mask_dir=None, transform=None):
        self.image_dir = image_dir
        self.entities_dir = entities_dir
        self.box_dir = box_dir
        self.transform = transform
        # os.listdir() order is arbitrary; sort so that indices are reproducible.
        self.images = sorted(os.listdir(image_dir))
        self.mask_dir = mask_dir

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.images)

    def _mask_path(self, image_name):
        """Return the mask file path for ``image_name`` (``<stem>_mask.jpg``)."""
        stem = os.path.splitext(image_name)[0]
        return os.path.join(self.mask_dir, stem + '_mask.jpg')

    def _box_path(self, image_name):
        """Return the box annotation path for ``image_name`` (``<stem>.txt``)."""
        stem = os.path.splitext(image_name)[0]
        return os.path.join(self.box_dir, stem + '.txt')

    def __getitem__(self, index):
        """Load the image at ``index`` together with its mask.

        Returns:
            Tuple ``(image, mask)``; both are passed through ``transform``
            when one was supplied.

        Raises:
            RuntimeError: If no mask directory is configured.
        """
        if self.mask_dir is None:
            raise RuntimeError(
                'mask_dir is not set; pass mask_dir or call create_masks() first'
            )

        image_name = self.images[index]
        with Image.open(os.path.join(self.image_dir, image_name)) as img:
            image = np.array(img.convert('RGB'))
        with Image.open(self._mask_path(image_name)) as mask_img:
            mask = np.array(mask_img.convert('L'), dtype=np.float32)

        # Masks are stored as JPEG, which is lossy: pixels are not guaranteed
        # to be exactly 0 or 255, so binarise with a mid-range threshold.
        mask = (mask > 127).astype(np.float32)

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations['image']
            mask = augmentations['mask']

        return image, mask

    def create_bbox_df(self, box_path):
        """Parse a box annotation file into a DataFrame.

        Every non-blank line is expected to look like
        ``x0,y0,x1,y1,x2,y2,x3,y3,text`` where ``text`` may itself contain
        commas.

        Args:
            box_path: Path of the annotation file.

        Returns:
            DataFrame with columns ``x0, y0, x2, y2, line`` (the ``x1/y1`` and
            ``x3/y3`` corners are dropped).

        Raises:
            ValueError: If a line has fewer than eight coordinates or a
                coordinate is not an integer.
        """
        bbox_word_list = []
        with open(box_path, 'r', errors='ignore') as f:
            for line_no, line in enumerate(f.read().splitlines(), start=1):
                if not line.strip():
                    continue

                # Split off only the eight coordinates; the remainder is the
                # transcription, commas included.
                fields = line.split(',', 8)
                if len(fields) < 8:
                    raise ValueError(
                        f'{box_path}:{line_no}: expected at least 8 coordinates, '
                        f'got {len(fields)}'
                    )

                bbox = np.array(fields[:8], dtype=np.int32)
                text = fields[8] if len(fields) > 8 else ''
                bbox_word_list.append([*bbox, text])

        bbox_df = pd.DataFrame(
            bbox_word_list,
            columns=['x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'line'],
        )
        return bbox_df.drop(columns=['x1', 'y1', 'x3', 'y3'])

    def create_masks(self):
        """Write a ``<stem>_mask.jpg`` mask for every image that lacks one.

        The mask has the image's height and width, is 255 inside each
        annotated box (rectangle spanned by ``(x0, y0)`` and ``(x2, y2)``) and
        0 elsewhere. Existing mask files are left untouched. If ``mask_dir``
        is ``None`` it is set to ``DEFAULT_MASK_DIR``; the directory is
        created when missing.

        Raises:
            ValueError: If an image cannot be decoded.
            OSError: If a mask file cannot be written.
        """
        if self.mask_dir is None:
            self.mask_dir = self.DEFAULT_MASK_DIR
        # Create the directory once, tolerating an already-existing one.
        os.makedirs(self.mask_dir, exist_ok=True)

        for image_name in tqdm(self.images):
            mask_path = self._mask_path(image_name)
            if os.path.exists(mask_path):
                # Nothing to do; avoid decoding the image and parsing boxes.
                continue

            img_path = os.path.join(self.image_dir, image_name)
            img_read = cv2.imread(img_path)
            if img_read is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise ValueError(f'Could not read image: {img_path}')

            height, width = img_read.shape[:2]
            mask = np.zeros((height, width), dtype=np.uint8)

            bbox_df = self.create_bbox_df(box_path=self._box_path(image_name))
            corners = bbox_df[['x0', 'y0', 'x2', 'y2']]
            for x0, y0, x2, y2 in corners.itertuples(index=False, name=None):
                # Clamp at 0: a negative slice bound would wrap around to the
                # opposite edge instead of clipping.
                mask[max(y0, 0):max(y2, 0), max(x0, 0):max(x2, 0)] = 255

            if not cv2.imwrite(mask_path, mask):
                raise OSError(f'Could not write mask: {mask_path}')
        

In [44]:
# Locations of the SROIE 2019 training split, keyed by constructor argument.
train_dirs = {
    'image_dir': '../data/SROIE2019/train/img/',
    'entities_dir': '../data/SROIE2019/train/entities',
    'box_dir': '../data/SROIE2019/train/box',
    'mask_dir': '../data/SROIE2019/train/mask/',
}
dataset = SROIEDataset(**train_dirs)

# Generate the mask image for every training image that does not have one yet.
dataset.create_masks()

# To inspect the parsed boxes of a single receipt:
#example_bbox_df = dataset.create_bbox_df(os.path.join(dataset.box_dir, 'X51005365179.txt'))
#print(example_bbox_df)

  0%|          | 0/626 [00:00<?, ?it/s]

Corrupt JPEG data: bad Huffman code
