In [7]:
import os
import re
import cv2
from torchvision import transforms, datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2


numbers = re.compile(r'(\d+)')


def numericalSort(value):
    """Return a natural-order sort key for ``value``.

    The string is split on runs of digits; the digit runs are converted to
    ``int`` while the text between them stays as ``str``. Used as a ``key``
    for sorting, this places ``"2.jpg"`` before ``"10.jpg"``.

    Args:
        value: The string to build a key for (e.g. a file name).

    Returns:
        A list alternating text fragments and integers, always starting
        with a (possibly empty) text fragment.
    """
    tokens = numbers.split(value)
    # Because the pattern has a capturing group, the matched digit runs are
    # kept in the split result and always land on the odd positions.
    return [int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)]


class OCRDataset():
    """Directory-backed dataset of grayscale images for OCR.

    Every entry returned by ``os.listdir(img_dir)`` is treated as an image
    file. Names are ordered with ``numericalSort`` so that numbered files
    come out in numeric order (``2.jpg`` before ``10.jpg``). Each sample is
    read with OpenCV, converted to grayscale and resized to
    ``inp_h`` x ``inp_w``.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable invoked as ``transform(image=image)``
            that returns a mapping with an ``"image"`` entry. It receives
            the image as an ``(inp_h, inp_w, 1)`` array.
        inp_h: Height every image is resized to. Defaults to 32.
        inp_w: Width every image is resized to. Defaults to 128.
    """

    def __init__(self, img_dir, transform=None, inp_h=32, inp_w=128):
        self.img_dir = img_dir
        self.inp_h = inp_h
        self.inp_w = inp_w

        # Single natural-order sort. The raw name is a tie-breaker so that
        # names with equal numeric keys (e.g. "01.jpg" and "1.jpg") still
        # get an order that does not depend on what os.listdir returns.
        self.img_names = sorted(
            os.listdir(img_dir),
            key=lambda name: (numericalSort(name), name))

        self.transform = transform

    def __len__(self):
        """Return the number of files found in ``img_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load, grayscale and resize the ``idx``-th image.

        Returns:
            A tuple ``(image, img_name, idx)``. With a ``transform`` set,
            ``image`` is whatever the transform stores under ``"image"``;
            otherwise it is a ``(1, inp_h, inp_w)`` array.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the offending path rather than with
            # an opaque error inside cvtColor.
            raise OSError("Could not read image: {}".format(img_path))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Pass the target size directly (dsize is (width, height)) so the
        # output shape is fixed by construction, not by rounding fx/fy.
        image = cv2.resize(image, (self.inp_w, self.inp_h),
                           interpolation=cv2.INTER_CUBIC)
        # Add an explicit channel axis: (H, W) -> (H, W, 1).
        image = image.reshape(self.inp_h, self.inp_w, 1)

        if self.transform is not None:
            image = self.transform(image=image)["image"]
        else:
            # No transform: move the channel axis first, (H, W, 1) -> (1, H, W).
            image = image.transpose(2, 0, 1)

        return image, img_name, idx


In [8]:
# Path to dataset
DATA_PATH = "data/Bangla_writing_hnhtrd/"


# Albumentations augmentation pipeline. Each augmentation is applied with
# its own probability `p`; ToTensorV2 then converts the result to a tensor.
# Transforms are taken from the top-level `A` namespace instead of reaching
# into the internal `A.augmentations.transforms` module path.
data_transform = A.Compose([
    A.ElasticTransform(alpha=0.5, sigma=0, alpha_affine=0, p=0.3),
    A.GaussNoise(var_limit=(120.0, 135.0), mean=0, p=0.6),
    A.MotionBlur(blur_limit=(3, 6), p=0.3),
    ToTensorV2(),
])


train_dataset = OCRDataset(
    os.path.join(DATA_PATH, "train_img"), transform=data_transform)

# Sanity check: fetch the first sample directly by index and show what the
# dataset returns for it.
image, img_name, idx = train_dataset[0]
print("Image -->", image)
print("img_name -->", img_name)
print("idx -->", idx)

Image --> tensor([[[246, 246, 249,  ..., 254, 251, 251],
         [252, 246, 248,  ..., 252, 252, 252],
         [254, 254, 255,  ..., 255, 255, 255],
         ...,
         [250, 255, 255,  ..., 248, 242, 242],
         [250, 252, 252,  ..., 246, 247, 247],
         [255, 254, 249,  ..., 254, 254, 254]]], dtype=torch.uint8)
img_name --> 0.jpg
idx --> 0
