# Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
from PIL import Image
from torch.utils.data import Dataset
import json

from tqdm.notebook import tqdm

# Dataset Class Definition

In [34]:
class FUNSDDataset(Dataset):
    """Image / word-mask dataset built from FUNSD-style JSON annotations.

    Every file in ``image_dir`` is treated as one sample. Its annotation is
    looked up in ``annotations_dir`` as ``<stem>.json`` and its binary word
    mask in ``mask_dir`` as ``<stem>_mask.png``. Masks are produced by
    :meth:`create_masks`.

    Args:
        image_dir: Directory containing the document images.
        annotations_dir: Directory containing one JSON annotation per image.
        mask_dir: Directory holding (or receiving) the mask images. May be
            ``None`` until :meth:`create_masks` has been called.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries (presumably an Albumentations
            pipeline -- confirm against the training code).
    """

    # Output directory used by ``create_masks`` when no ``mask_dir`` was given.
    DEFAULT_MASK_DIR = '../data/FUNSD/training_data/mask/'
    # Suffix appended to an image's stem to form its mask file name.
    MASK_SUFFIX = '_mask.png'

    def __init__(self, image_dir, annotations_dir, mask_dir=None, transform=None):
        self.image_dir = image_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so that indices are reproducible.
        self.images = sorted(os.listdir(image_dir))
        self.mask_dir = mask_dir

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.images)

    def _mask_path(self, image_name):
        """Return the mask path for ``image_name`` inside ``self.mask_dir``."""
        stem = os.path.splitext(image_name)[0]
        return os.path.join(self.mask_dir, stem + self.MASK_SUFFIX)

    def __getitem__(self, index):
        """Load the image and its word mask at ``index``.

        Returns:
            ``(image, mask)``: ``image`` is the RGB image as an array and
            ``mask`` is a float32 array in which the value 255 has been
            mapped to 1.0. If a transform is set, both are the transform's
            ``'image'`` and ``'mask'`` outputs instead.

        Raises:
            ValueError: If no mask directory is configured.
        """
        if self.mask_dir is None:
            raise ValueError(
                'mask_dir is not set; pass mask_dir to the constructor or '
                'call create_masks() first'
            )

        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = self._mask_path(image_name)

        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))
        # The mask comes from its own file, not from the input image.
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert('L'), dtype=np.float32)

        mask[mask == 255.0] = 1.0
        if self.transform is not None:
            # Pass the mask too so that it receives the same transform call.
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations['image']
            mask = augmentations['mask']

        return image, mask

    def create_bbox_df(self, box_path):
        """Read one annotation file into a DataFrame of word boxes.

        Args:
            box_path: Path to a JSON file whose ``'form'`` list holds entries
                with a ``'words'`` list; each word has a four-element
                ``'box'`` and a ``'text'``.

        Returns:
            DataFrame with columns ``x0, y0, x2, y2, line``, one row per word.
        """
        with open(box_path, 'r', encoding='utf-8', errors='ignore') as f:
            data = json.load(f)

        bbox_word_list = [
            [*word['box'], word['text']]
            for form in data['form']
            for word in form['words']
        ]
        return pd.DataFrame(bbox_word_list,
                            columns=['x0', 'y0', 'x2', 'y2', 'line'])

    def create_masks(self):
        """Write a word mask next to every image that does not have one yet.

        The mask has the image's height and width, is 255 inside every word
        box and 0 elsewhere. If ``mask_dir`` is unset it becomes
        ``DEFAULT_MASK_DIR``; the directory is created when missing.

        Raises:
            OSError: If an image cannot be read.
        """
        if self.mask_dir is None:
            self.mask_dir = self.DEFAULT_MASK_DIR
        # makedirs(exist_ok=True) keeps the method safe to re-run.
        os.makedirs(self.mask_dir, exist_ok=True)

        for img in tqdm(self.images):
            mask_path = self._mask_path(img)
            if os.path.exists(mask_path):
                # Existing masks are never overwritten, so skip the work.
                continue

            stem = os.path.splitext(img)[0]
            annotations_path = os.path.join(self.annotations_dir, stem + '.json')
            bbox_df = self.create_bbox_df(box_path=annotations_path)

            img_path = os.path.join(self.image_dir, img)
            img_read = cv2.imread(img_path)
            if img_read is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError(f'could not read image: {img_path}')

            height, width = img_read.shape[:2]
            mask = np.zeros((height, width), dtype=np.uint8)
            boxes = bbox_df[['x0', 'y0', 'x2', 'y2']]
            for x0, y0, x2, y2 in boxes.itertuples(index=False, name=None):
                mask[y0:y2, x0:x2] = 255

            cv2.imwrite(mask_path, mask)
        

In [35]:
# Build the training dataset and generate its word masks.
# mask_dir is intentionally not supplied here; to reuse an existing mask
# folder, add mask_dir='../data/FUNSD/training_data/mask/' to the arguments.
dataset_kwargs = dict(
    image_dir='../data/FUNSD/training_data/images/',
    annotations_dir='../data/FUNSD/training_data/annotations',
)
dataset = FUNSDDataset(**dataset_kwargs)

dataset.create_masks()
# Quick manual check of a single annotation file:
#   example_bbox_df = dataset.create_bbox_df(
#       os.path.join(dataset.annotations_dir, '0000971160.json'))
#   print(example_bbox_df)

  0%|          | 0/149 [00:00<?, ?it/s]