In [1]:
import cv2
import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET
from xml.dom import minidom
import scipy.io as sio
import os
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms, datasets

import albumentations as A
from albumentations.pytorch import ToTensorV2

### MJSynth

In [2]:
class MJSynthDataset(data.Dataset):
    """MJSynth word images served as fixed-size (32 x 128) grayscale arrays.

    Every annotation line is parsed as ``./<relative/path> <label>``; only the
    path part is used by this class.
    """

    def __init__(self, img_dir, annotation_file, transform=None):
        """
        Args:
            img_dir: directory the annotation paths are relative to.
            annotation_file: text file with one image entry per line.
            transform: optional callable invoked as ``transform(image=arr)``;
                its ``["image"]`` entry is returned as the sample image.
        """
        self.s_img_dir = img_dir
        self.inp_h = 32
        self.inp_w = 128
        # Stored for external use; this class never applies them itself.
        self.mean = np.array(0.588, dtype=np.float32)
        self.std = np.array(0.193, dtype=np.float32)

        with open(annotation_file, 'r') as file:
            # Sorted so that an index always maps to the same entry.
            self.s_images = sorted(file.readlines())

        self.transform = transform

    def __len__(self):
        return len(self.s_images)

    @staticmethod
    def _parse_entry(line):
        """Extract the relative image path from one annotation line."""
        # First space-separated token is the path; strip() drops the trailing
        # newline that remains when a line carries no label.
        entry = line.split(' ')[0].strip()
        if entry.startswith('./'):
            entry = entry[2:]
        if not entry:
            raise ValueError("annotation line has no image path: {!r}".format(line))
        return entry

    def _load_gray(self, img_name):
        """Read ``img_name`` under ``s_img_dir`` and return it as a 2-D grayscale array."""
        path = os.path.join(self.s_img_dir, img_name)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError("could not read image: {}".format(path))
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def __getitem__(self, idx):
        """Return ``(image, img_name, idx)`` for the entry at ``idx``.

        If the entry cannot be parsed or its image cannot be read, the entry
        at ``idx - 1`` is used instead (best effort); ``img_name`` then names
        the substitute while ``idx`` is still the requested index.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        # Looked up outside the try block so an out-of-range index propagates
        # instead of being masked by the fallback below.
        line = self.s_images[idx]
        try:
            img_name = self._parse_entry(line)
            image = self._load_gray(img_name)
        except Exception:
            # Best-effort fallback to the previous entry; a failure here is
            # allowed to propagate with its own message.
            img_name = self._parse_entry(self.s_images[idx - 1])
            image = self._load_gray(img_name)

        # An explicit target size guarantees the (inp_h, inp_w) shape that the
        # reshape below requires; dsize is given as (width, height).
        image = cv2.resize(image, (self.inp_w, self.inp_h), interpolation=cv2.INTER_CUBIC)
        image = np.reshape(image, (self.inp_h, self.inp_w, 1))

        if self.transform is not None:
            image = self.transform(image=image)["image"]
            return image, img_name, idx

        # Without a transform, hand back a channel-first (1, H, W) array.
        image = image.transpose(2, 0, 1)
        return image, img_name, idx

### IIIT5k

In [None]:
# Read the train and test .mat annotation files in one pass (train first).
train_data, test_data = (
    sio.loadmat(mat_path)
    for mat_path in (
        "/home/ec2-user/word_level_ocr/computer/datasets/S2S_data/IIIT5K/traindata.mat",
        "/home/ec2-user/word_level_ocr/computer/datasets/S2S_data/IIIT5K/testdata.mat",
    )
)

In [None]:
# Summarise the loaded .mat dict (key, value type, array shape if any) instead
# of dumping every annotation record into the cell output.
for mat_key, mat_value in train_data.items():
    print(mat_key, type(mat_value).__name__, getattr(mat_value, 'shape', ''))

In [None]:
# Number of records in the train and test annotation arrays, one per line.
print(
    len(train_data['traindata'][0]),
    len(test_data['testdata'][0]),
    sep='\n',
)

In [None]:
# Row 0 of the loaded array holds the per-image records.
images = train_data['traindata'][0]
# Field 0 of record 1, then field 1 of record 0
# (presumably image name and ground-truth word - confirm against the .mat schema).
print(images[1][0][0], images[0][1][0], sep='\n')

In [None]:
class IIIT5kDataset(data.Dataset):
    """Word images listed in a ``.mat`` annotation file, served as 32 x 128 grayscale arrays.

    The annotation file is read with ``scipy.io.loadmat``; records come from
    its ``'traindata'`` or ``'testdata'`` entry depending on ``train``.
    """

    def __init__(self, img_dir, annotation_file, train=True, transform=None):
        """
        Args:
            img_dir: directory the image names in the annotation file are relative to.
            annotation_file: path to the ``.mat`` annotation file.
            train: read the ``'traindata'`` entry when True, ``'testdata'`` otherwise.
            transform: optional callable invoked as ``transform(image=arr)``;
                its ``["image"]`` entry is returned as the sample image.

        Raises:
            KeyError: if the entry selected by ``train`` is not in the file.
        """
        self.img_dir = img_dir
        self.inp_h = 32
        self.inp_w = 128
        # Stored for external use; this class never applies them itself.
        self.mean = np.array(0.588, dtype=np.float32)
        self.std = np.array(0.193, dtype=np.float32)

        # Named `mat` so it does not shadow the `data` alias the base class comes from.
        mat = sio.loadmat(annotation_file)
        key = 'traindata' if train else 'testdata'
        if key not in mat:
            raise KeyError(
                "'{}' not found in {}; check that `train` matches the annotation file".format(
                    key, annotation_file
                )
            )
        self.images = mat[key][0]

        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, img_name, idx)`` for the record at ``idx``.

        Raises:
            FileNotFoundError: if the image cannot be read from ``img_dir``.
        """
        # First field of the record holds the image name wrapped in an array.
        img_name = self.images[idx][0][0]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError("could not read image: {}".format(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # An explicit target size guarantees the (inp_h, inp_w) shape that the
        # reshape below requires; dsize is given as (width, height).
        image = cv2.resize(image, (self.inp_w, self.inp_h), interpolation=cv2.INTER_CUBIC)
        image = np.reshape(image, (self.inp_h, self.inp_w, 1))

        if self.transform is not None:
            image = self.transform(image=image)["image"]
            return image, img_name, idx

        # Without a transform, hand back a channel-first (1, H, W) array.
        image = image.transpose(2, 0, 1)
        return image, img_name, idx

In [None]:
# --- Dataset location (IIIT5K active; earlier alternatives kept for reference) ---
# img_dir = "/home/ec2-user/word_level_ocr/pritom/datasets/out_50000_Synthetic_controlled_12lakh"
# img_dir = '/home/ec2-user/word_level_ocr/computer/datasets/S2S_data/MJSYNTH/mnt/ramdisk/max/90kDICT32px/'
img_dir = '/home/ec2-user/word_level_ocr/computer/datasets/S2S_data/IIIT5K/'

# annotation_file = img_dir + 'annotation_train.txt'
annotation_file = img_dir + 'traindata.mat'

# --- Loader / split settings ---
train_batch_s = 1
valid_batch_s = 1
random_seed = 42
validation_split = .15
shuffle_dataset = True

# --- Albumentations noise pipeline ---
# Note: this pipeline is built here but the dataset below is created with
# transform=None, so it is not applied to any sample.
data_transform = A.Compose(
    [
        A.augmentations.transforms.GaussNoise(
            var_limit=(120.0, 135.0), mean=0, always_apply=False, p=0.5
        ),
        A.imgaug.transforms.IAAAdditiveGaussianNoise(
            loc=1,
            scale=(2.5500000000000003, 12.75),
            per_channel=False,
            always_apply=False,
            p=0.5,
        ),
        A.augmentations.transforms.MotionBlur(p=0.5),
        ToTensorV2(),
    ]
)

# --- Dataset ---
# ocr_dataset = OCRDataset(img_dir, transform=data_transform)
# ocr_dataset = MJSynthDataset(img_dir=img_dir, annotation_file=annotation_file, transform=None)
ocr_dataset = IIIT5kDataset(
    img_dir=img_dir,
    annotation_file=annotation_file,
    transform=None,
)

# --- Train / validation index split ---
dataset_size = len(ocr_dataset)
split = int(np.floor(validation_split * dataset_size))
indices = list(range(dataset_size))
if shuffle_dataset:
    # Seed the global NumPy RNG so the shuffled split is repeatable.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# The first `split` indices go to validation, the remainder to training.
val_indices = indices[:split]
train_indices = indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Both loaders wrap the same dataset object and differ only in their sampler.
train_loader = torch.utils.data.DataLoader(
    ocr_dataset,
    batch_size=train_batch_s,
    sampler=train_sampler,
    num_workers=4,
)
validation_loader = torch.utils.data.DataLoader(
    ocr_dataset,
    batch_size=valid_batch_s,
    sampler=valid_sampler,
    num_workers=4,
)

In [None]:
# Preview the first `how_many_to_plot` training batches on a 10 x 10 grid
# (so at most 100 panels fit).
how_many_to_plot = 20

plt.figure(figsize=(50,50))
for i, batch in enumerate(train_loader, start=1):
    image, label, _ = batch
    plt.subplot(10,10,i)
    # Show the first sample of the batch as a 32 x 128 grayscale picture.
    plt.imshow(image[0].reshape(32,128), cmap='gray')
    plt.axis('off')
    # `label` is the batch-level collection of names; use the entry that
    # belongs to image[0] so the title is the bare name, not a list repr.
    plt.title(label[0], fontsize=24)
    if i >= how_many_to_plot:
        break
plt.show()

### SVT (Create Dataset)

In [74]:
import os
from bs4 import BeautifulSoup
import cv2
from PIL import Image
from tqdm import tqdm


# Crop every tagged word rectangle out of the images listed in the XML
# annotation file and save each crop as "<count>_<word>.jpg".
data_dir = "/home/ec2-user/word_level_ocr/computer/datasets/S2S_data/svt1/"

extracted = "extracted/train/"
# extracted = "extracted/test/"

try:
    # Create the output directory for the crops.
    os.makedirs(os.path.join(data_dir, extracted))
except FileExistsError:
    # Re-running the cell is fine; just report that the directory is there.
    print("Directory already exists.")

# Use "test.xml" here (together with the test output directory above) to
# extract the test split instead.
# `with` guarantees the annotation file handle is closed after reading.
with open(os.path.join(data_dir, "train.xml"), "r") as infile:
    contents = infile.read()

soup = BeautifulSoup(contents,'xml')
img_names = soup.find_all('imageName')
count = 0

for name in tqdm(img_names):

    image_path = os.path.join(data_dir, name.text)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None for a missing/unreadable file; skip it
        # rather than failing on the slice below.
        tqdm.write("Skipping unreadable image: " + image_path)
        continue

    words = name.find_next_sibling("taggedRectangles").find_all("taggedRectangle")
    for word in words:
        x = int(word["x"])
        y = int(word["y"])
        h = int(word["height"])
        w = int(word["width"])

        # Clamp the box to the image origin: a negative start would otherwise
        # be read as a from-the-end index and give an empty or wrong crop.
        # Bounds past the far edges are already clipped by slicing.
        top, bottom = max(y, 0), max(y + h, 0)
        left, right = max(x, 0), max(x + w, 0)
        word_img = image[top:bottom, left:right]
        word_name = word.find_next("tag").text

        # Nothing left to save for degenerate boxes.
        if word_img.shape[0] == 0 or word_img.shape[1] == 0:
            continue

        cv2.imwrite(os.path.join(data_dir, extracted, str(count) + "_" + word_name + ".jpg"), word_img)

        count+=1

# l = soup.tagset()
# l[0].get_text()


100%|██████████| 100/100 [00:02<00:00, 49.38it/s]
