In [None]:
import cv2
import numpy as np
import os
import pandas as pd
import pickle

from torch.utils.data import Dataset


# Dataset

In [None]:
class NomImageDataset(Dataset):
    """Dataset pairing images with per-character bounding-box annotations.

    Every file found in ``image_dir`` is expected to have a companion text
    file ``<image stem>.txt`` in ``annotation_dir``. Each non-empty line of
    that file has the form ``x_tl,y_tl,x_br,y_br,label``: four integer box
    coordinates followed by a label that must be a key of the pickled
    unicode dictionary.

    Args:
        image_dir: Directory containing the image files.
        annotation_dir: Directory containing one ``.txt`` file per image.
        unicode_dict_path: Path to a pickled mapping whose keys are the
            labels used in the annotation files.
        transform: Optional callable applied to the loaded image in
            ``__getitem__``.

    Attributes:
        root_dir: Same value as ``image_dir``.
        unicode_dict: Maps each label to its integer class index, assigned
            in the iteration order of the pickled mapping.
        annotations: Maps image name (file name without extension) to a
            list of ``(x_tl, y_tl, x_br, y_br, class_index)`` tuples.
        image_list: Image names, in the order used for integer indexing.
        transform: The ``transform`` argument, stored unchanged.

    Raises:
        FileNotFoundError: If an annotation file is missing.
        ValueError: If an annotation line does not have exactly five
            comma-separated fields or a coordinate is not an integer.
        KeyError: If an annotation label is not in the unicode dictionary.
    """

    def __init__(self, image_dir, annotation_dir, unicode_dict_path, transform=None):
        self.root_dir = image_dir
        self.annotations = dict()
        self.transform = transform

        # NOTE(security): unpickling can execute arbitrary code; only point
        # this at dictionary files that come from a trusted source.
        # Pickle data is bytes, so the file must be opened in binary mode.
        with open(unicode_dict_path, 'rb') as f:
            loaded_dict = pickle.load(f)
        # Replace whatever values the pickle stored with dense class indices.
        self.unicode_dict = {key: idx for idx, key in enumerate(loaded_dict)}

        # Image name -> actual file name, so __getitem__ does not have to
        # guess the extension.
        self._image_files = {}

        # os.listdir order is arbitrary; sorting keeps indices reproducible.
        for image_file in sorted(os.listdir(image_dir)):
            # splitext drops only the final extension, so names containing
            # extra dots (e.g. "page.v2.jpg") keep their full stem.
            image_name = os.path.splitext(image_file)[0]
            annotation_path = os.path.join(annotation_dir, image_name + '.txt')

            self._image_files[image_name] = image_file
            self.annotations[image_name] = self._load_annotations(annotation_path)

        self.image_list = list(self.annotations.keys())

    def _load_annotations(self, annotation_path):
        """Parse one annotation file into a list of box/class-index tuples."""
        img_annotations = []
        # Labels may be non-ASCII characters, so do not rely on the platform
        # default encoding.
        with open(annotation_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Line format: x_tl, y_tl, x_br, y_br, label
                x_tl, y_tl, x_br, y_br, label = line.split(',')
                img_annotations.append(
                    (int(x_tl), int(y_tl), int(x_br), int(y_br),
                     self.unicode_dict[label])
                )
        return img_annotations

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, annotations)`` for the image at position ``idx``.

        The image is whatever ``cv2.imread`` returns (OpenCV loads colour
        images in BGR channel order), passed through ``transform`` when one
        was given. ``annotations`` is the list of
        ``(x_tl, y_tl, x_br, y_br, class_index)`` tuples for that image.

        Raises:
            IndexError: If ``idx`` is out of range.
            FileNotFoundError: If the image file cannot be read.
        """
        image_name = self.image_list[idx]
        # Fall back to the historical ".jpg" convention if the name was not
        # recorded at construction time.
        image_file = self._image_files.get(image_name, image_name + '.jpg')
        image_path = os.path.join(self.root_dir, image_file)

        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')

        annotations = self.annotations[image_name]

        if self.transform is not None:
            image = self.transform(image)

        return image, annotations
    
# Instantiate the dataset from the directory layout under ``data/``.
dataset = NomImageDataset(
    image_dir='data/images',
    annotation_dir='data/annotations',
    unicode_dict_path='data/unicode_dict.pkl',
)