In [2]:
import os
from data_source import preproc as pp
from multiprocessing import Pool
from functools import partial
import string
import h5py

In [1]:
class Dataset():
    """Load, clean and store the IAM line dataset.

    Methods:
        read_iam: reads the partition lists and ``lines.txt`` and stores the
            train / valid / test samples in ``self.dataset``.
        preprocess_partitions: standardizes and filters the transcriptions
            and runs ``pp.preprocess`` on every image path.
        save: writes all partitions to an HDF5 file.
        check_text: decides whether a transcription is worth keeping.

    ``self.dataset`` structure once ``read_iam`` has run::

        {train: {dt: [], gt: []}, valid: {dt: [], gt: []}, test: {dt: [], gt: []}}

    where ``dt`` holds the image data (paths until preprocessing replaces
    them) and ``gt`` the matching ground-truth text.
    """

    def __init__(self, iam_path):
        """Remember the dataset root; nothing is read from disk here."""
        self.iam_path = iam_path
        self.dataset = None
        self.partitions = ['train', 'valid', 'test']

    @staticmethod
    def _read_lines(path):
        """Return the lines of the text file at *path*, closing the file."""
        with open(path) as handle:
            return handle.read().splitlines()

    def read_iam(self):
        """Read the IAM partition lists and transcriptions into ``self.dataset``.

        Expects ``<iam_path>/partitions/{trainset,validationset1,testset}.txt``
        and ``<iam_path>/lines.txt``. Only ``lines.txt`` entries whose second
        field is ``ok`` are used; partition ids without such an entry
        (including blank lines) are skipped.
        """
        pt_path = os.path.join(self.iam_path, "partitions")
        paths = {
            "train": self._read_lines(os.path.join(pt_path, "trainset.txt")),
            "valid": self._read_lines(os.path.join(pt_path, "validationset1.txt")),
            "test": self._read_lines(os.path.join(pt_path, "testset.txt")),
        }

        gt_dict = {}
        for line in self._read_lines(os.path.join(self.iam_path, "lines.txt")):
            if not line or line[0] == "#":
                continue

            fields = line.split()
            # Whitespace-only or truncated lines carry no status field.
            if len(fields) < 2:
                continue

            if fields[1] == "ok":
                # The transcription starts at the 9th field; words are
                # separated by "|" in the file.
                gt_dict[fields[0]] = " ".join(fields[8:]).replace("|", " ")

        dataset = {}
        for name in self.partitions:
            dataset[name] = {"dt": [], "gt": []}

            for line in paths[name]:
                # Look the transcription up first so that ids without one
                # are skipped before the id is taken apart.
                if line not in gt_dict:
                    continue

                parts = line.split("-")
                folder = f"{parts[0]}-{parts[1]}"
                image = f"{parts[0]}-{parts[1]}-{parts[2]}.png"
                image_path = os.path.join(self.iam_path, "lines", parts[0], folder, image)

                dataset[name]['gt'].append(gt_dict[line])
                dataset[name]['dt'].append(image_path)

        self.dataset = dataset

    def preprocess_partitions(self, input_size):
        """Clean the transcriptions and preprocess the images of each partition.

        Every ground-truth string goes through ``pp.text_standardize``;
        samples rejected by ``check_text`` are dropped together with their
        image. Kept strings are stored encoded as bytes. Afterwards
        ``pp.preprocess(item, input_size=input_size)`` is mapped over the
        remaining ``dt`` entries in a process pool and replaces them.

        Requires ``read_iam`` to have populated ``self.dataset``.
        """
        print("Partitions will be preprocessed...")

        for name in self.partitions:
            part = self.dataset[name]
            kept_dt, kept_gt = [], []

            # Build the filtered lists in one pass instead of popping from
            # the middle of the originals.
            for item, raw_text in zip(part['dt'], part['gt']):
                # handles spaces around punctuation (per the original note)
                text = pp.text_standardize(raw_text)

                if not self.check_text(text):
                    # drop samples that are mostly punctuation
                    continue

                kept_dt.append(item)
                kept_gt.append(text.encode())

            part['gt'] = kept_gt

            pool = Pool()
            try:
                # partial fixes input_size so map only supplies the item
                part['dt'] = pool.map(partial(pp.preprocess, input_size=input_size), kept_dt)
            finally:
                # always release the worker processes, even if map raised
                pool.close()
                pool.join()

        print("Partitions preprocessing finished")

    def save(self, source_path):
        """Write every partition of ``self.dataset`` to the HDF5 file *source_path*.

        Creates the parent directory when the path has one. The file is
        opened in append mode and receives ``<partition>/dt`` and
        ``<partition>/gt`` datasets, gzip-compressed at level 9.
        """
        target_dir = os.path.dirname(source_path)
        # A bare filename has no directory part; os.makedirs("") would fail.
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        with h5py.File(source_path, "a") as hf:
            for name in self.partitions:
                part = self.dataset[name]
                hf.create_dataset(f"{name}/dt", data=part['dt'], compression="gzip", compression_opts=9)
                hf.create_dataset(f"{name}/gt", data=part['gt'], compression="gzip", compression_opts=9)
                print(f"[OK] {name} partition.")

        print("Transformation finished.")

    @staticmethod
    def check_text(text):
        """Return True when *text* is mostly characters rather than punctuation.

        The text is rejected when it is empty, when nothing remains after
        removing punctuation, when fewer than two characters remain, or when
        more than 10% of the punctuation-trimmed text is punctuation.
        """
        strip_punc = text.strip(string.punctuation).strip()
        no_punc = text.translate(str.maketrans("", "", string.punctuation)).strip()

        if not text or not strip_punc or not no_punc:
            return False

        # Share of interior punctuation relative to the trimmed text.
        punc_percent = (len(strip_punc) - len(no_punc)) / len(strip_punc)

        return len(no_punc) >= 2 and punc_percent <= 0.1