In [1]:
'''
Location of 'project' directory.
'''
import os

# Move up two levels to the project root, but only the first time this cell
# runs in a kernel session: os.chdir is relative to the *current* directory,
# so re-running an unguarded cell would climb two more levels each time and
# silently leave BASE_DIR pointing at the wrong place.
if 'BASE_DIR' not in globals():
    os.chdir(os.path.join('..', '..'))
    BASE_DIR = os.getcwd()

# BASE_DIR is working directory for this notebook
print(os.getcwd())

D:\ML\Projects\htr\htr-github


## Import libraries

In [2]:
import os
import json
import cv2
import numpy as np
import h5py
from htr.utils import split_data
from htr.preprocessing import preprocess
from tqdm.notebook import tqdm

## Read data (BanglaWriting dataset)

In [3]:
BANGLAWRITING_DIR = os.path.join(BASE_DIR, 'raw', 'BanglaWriting')

In [37]:
def is_img_valid(img):
    """Return True if ``img`` has a non-empty extent along its first two axes.

    Parameters
    ----------
    img : object
        Anything exposing a ``shape`` attribute (e.g. a NumPy array).
        Objects without a usable two-entry ``shape`` -- ``None``, 1-D arrays,
        plain Python values -- are treated as invalid rather than raising.

    Returns
    -------
    bool
        True when both ``img.shape[0]`` and ``img.shape[1]`` are greater
        than zero, False otherwise.
    """
    try:
        return bool(img.shape[0] > 0 and img.shape[1] > 0)
    except (AttributeError, IndexError, TypeError):
        # AttributeError: no ``shape`` (e.g. None); IndexError: fewer than two
        # dimensions; TypeError: ``shape`` is not subscriptable/comparable.
        # Deliberately not a bare ``except`` so KeyboardInterrupt and
        # SystemExit still propagate.
        return False

In [49]:
# Build the list of word-level samples: one dict per annotated rectangle,
# holding the cropped grayscale word image ('img') and its label ('gt_text').
samples = []

for root, dirs, files in os.walk(BANGLAWRITING_DIR):
    # os.walk yields file names in arbitrary, platform-dependent order; sort
    # them so the resulting sample order is deterministic.
    for file in tqdm(sorted(files)):
        if not file.endswith('.jpg'):
            continue

        img = cv2.imread(os.path.join(root, file), cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f'Couldn\'t load image {file}')
            continue

        # Look the annotation file up by name instead of relying on it being
        # listed right after its image. splitext keeps any extra dots that
        # are part of the base name.
        json_file = os.path.splitext(file)[0] + '.json'
        json_path = os.path.join(root, json_file)
        if not os.path.isfile(json_path):
            print(f'No annotation file found for image {file}')
            continue

        with open(json_path, encoding='utf-8') as jf:
            data = json.load(jf)

        for shape in data['shapes']:
            # Each shape carries two corner points of the word's bounding box.
            # NOTE(review): the crop assumes the first point is the top-left
            # corner; boxes stored in another order yield an empty crop and
            # are reported as invalid below -- confirm whether they should be
            # normalised instead.
            (xmin, ymin), (xmax, ymax) = shape['points']
            sample = {}
            sample['img'] = img[int(ymin):int(ymax), int(xmin):int(xmax)]
            sample['gt_text'] = shape['label']

            if not is_img_valid(sample['img']):
                print(f"Invalid word image ({sample['gt_text']}) on file {json_file}")
                continue

            samples.append(sample)

print('# of samples:', len(samples))

HBox(children=(FloatProgress(value=0.0, max=520.0), HTML(value='')))

Invalid word image (*) on file 126_17_0.json
Invalid word image (*) on file 126_17_0.json
Invalid word image (*) on file 126_17_0.json
Invalid word image (*) on file 128_16_1.json
Invalid word image (শ্রেষ্ঠ) on file 183_16_0.json
Invalid word image (*) on file 231_14_1.json
Invalid word image (*) on file 234_17_1.json
Invalid word image (*) on file 234_17_1.json
Invalid word image (*) on file 234_17_1.json
Invalid word image (*) on file 234_17_1.json
Invalid word image (ব) on file 256_14_1.json
Invalid word image (*) on file 68_12_0.json
Invalid word image (*) on file 80_14_0.json

# of samples: 21221


## Split data

In [50]:
# Partition the samples into train / validation / test subsets
# (2% of the samples requested for validation and 2% for test).
dataset = split_data(samples, val_split_size=0.02, test_split_size=0.02)

# Report the size of every subset.
for title, key in (('Train', 'train'), ('Validation', 'val'), ('Test', 'test')):
    print(f'{title} images:', len(dataset[key]))

Train images: 20373
Validation images: 424
Test images: 424


## Preprocess samples and save in an HDF5 file

In [6]:
# Parameters
# Path of the HDF5 file that the cells below create and then fill.
dataset_path = os.path.join(BASE_DIR, 'htr', 'data', 'BanglaWriting.hdf5')
# Target size handed to preprocess(); its first two entries also give the
# per-image dimensions of the stored 'imgs' datasets.
# NOTE(review): presumably (width, height, channels) -- confirm against htr.preprocessing.
input_size = (1024, 128, 1)
# Size used for each entry when the 'gt_texts' datasets are created.
max_text_len = 128
# Number of samples preprocessed and written to the file per iteration.
batch_size = 1024

*HDF5 file structure*

    BanglaWriting.hdf5
       │    
       ├───train
       │   ├───imgs
       │   └───gt_texts
       ├───val
       │   ├───imgs
       │   └───gt_texts
       └───test
           ├───imgs
           └───gt_texts

In [53]:
# Create the output file with one 'imgs' and one 'gt_texts' dataset per split.
# The datasets are only declared here; the next cell fills them batch by batch.
with h5py.File(dataset_path, 'w') as hf:
    for s in ['train', 'val', 'test']:
        n_samples = len(dataset[s])

        # Declare shape and dtype instead of passing a zero-filled array as
        # `data`: the placeholder array would need
        # n_samples * input_size[0] * input_size[1] bytes of RAM just to be
        # thrown away, while an unwritten dataset reads back as zeros anyway.
        hf.create_dataset(f'{s}/imgs',
                          shape=(n_samples, input_size[0], input_size[1]),
                          dtype=np.uint8,
                          compression='gzip',
                          compression_opts=9)

        # Fixed-length byte strings of max_text_len bytes each -- the same
        # dtype the former list of 'c' * max_text_len placeholders produced.
        hf.create_dataset(f'{s}/gt_texts',
                          shape=(n_samples,),
                          dtype=f'S{max_text_len}',
                          compression='gzip',
                          compression_opts=9)

In [54]:
# Preprocess every split in chunks of `batch_size` samples and write each
# chunk into the matching slice of the datasets created above.
for split in ['train', 'val', 'test']:
    print(split)
    split_samples = dataset[split]

    for start in tqdm(range(0, len(split_samples), batch_size)):
        stop = start + batch_size
        chunk = split_samples[start:stop]

        imgs = [preprocess(entry['img'], input_size) for entry in chunk]
        gt_texts = [entry['gt_text'].encode() for entry in chunk]

        # The file is reopened in append mode for every chunk, so each
        # chunk's data is written and the file closed before the next one.
        with h5py.File(dataset_path, 'a') as hf:
            hf[f'{split}/imgs'][start:stop] = imgs
            hf[f'{split}/gt_texts'][start:stop] = gt_texts

train


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


val


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


test


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


