In [1]:
from typing import Tuple
from torch.utils.data import Dataset
import pathlib
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pandas as pd 
import numpy
import string

class Cangjie_Class():
    """Label table and label codec built from a whitespace-separated label file.

    The file is read into ``self.class_df`` with the columns ``id``, ``char``,
    ``hex``, ``uni`` and ``label`` plus a derived ``cls`` column
    (0 when ``label == "zc"``, otherwise 1).

    Two encodings of label strings coexist on this class:

    * ``CHAR2LABEL`` / ``LABEL2CHAR`` map ``'a'..'z'`` to ``1..26`` (1-based).
    * ``encode_to_labels`` / ``decode_to_classname`` use 0-based positions in
      ``self.char_list`` and use ``len(self.char_list)`` (26) as padding.
    """
    CHARS = 'abcdefghijklmnopqrstuvwxyz'
    CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}
    LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}

    def __init__(self, file_path, encoding="utf-8"):
        """Parse the label file at ``file_path`` into ``self.class_df``.

        Args:
            file_path: Path of the label file. The first line is skipped; every
                following non-blank line must hold exactly five
                whitespace-separated fields: id, char, hex, uni, label.
            encoding: Text encoding used to read the file. Passed explicitly so
                the result does not depend on the platform's default locale.

        Raises:
            ValueError: If a non-blank line does not have exactly five fields
                or its id field is not an integer.
        """
        # Same 26 lowercase letters as CHARS; the codec methods index into it.
        self.char_list = self.CHARS

        columns = {
            "id": [],
            "char": [],
            "hex": [],
            "uni": [],
            "label": [],
        }
        with open(file_path, "r", encoding=encoding) as f:
            f.readline()  # skip the first line
            for line_no, line in enumerate(f, start=2):
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no record.
                    continue
                if len(fields) != 5:
                    raise ValueError(
                        f"{file_path}:{line_no}: expected 5 whitespace-separated "
                        f"fields, got {len(fields)}"
                    )
                row_id, char, hex_code, uni, label = fields
                columns["id"].append(int(row_id))
                columns["char"].append(char)
                columns["hex"].append(hex_code)
                columns["uni"].append(uni)
                columns["label"].append(label)

        self.class_df = pd.DataFrame(columns)

        # Vectorised flag: 0 for the "zc" label, 1 for everything else.
        # Unlike a row-wise apply, this also works when the table is empty.
        self.class_df["cls"] = (self.class_df["label"] != "zc").astype("int64")

    def get_class_name_from_path(self, image_path):
        """Return the label of the class an image file belongs to.

        The name of the image's parent directory is parsed as an integer and
        used as a *positional* row index into ``self.class_df``.

        Args:
            image_path: ``pathlib.Path`` (or path string) of the image file.

        Raises:
            ValueError: If the parent directory name is not an integer.
            IndexError: If the integer is outside the table.
        """
        row = int(pathlib.Path(image_path).parent.stem)
        return self.class_df["label"].iloc[row]

    def encode_to_labels(self, txt, max_len=7):
        """Encode ``txt`` as 0-based letter positions, right-padded to ``max_len``.

        Characters that are not in ``self.char_list`` are printed and skipped
        (best-effort behaviour). Padding uses ``len(self.char_list)``.
        Sequences already longer than ``max_len`` are returned unpadded and
        untruncated.

        Args:
            txt: Iterable of single characters (typically a label string).
            max_len: Minimum length of the returned list.

        Returns:
            list[int]: The encoded, padded sequence.
        """
        pad_value = len(self.char_list)
        dig_lst = []
        for char in txt:
            try:
                dig_lst.append(self.char_list.index(char))
            except (TypeError, ValueError):
                # Unknown symbol: report it and carry on with the rest.
                print(char)

        dig_lst.extend([pad_value] * (max_len - len(dig_lst)))
        return dig_lst

    def decode_to_classname(self, dig_lst, max_len=5):
        """Decode the first ``max_len`` entries of ``dig_lst`` back to a string.

        Each entry is converted with ``int()`` and used as an index into
        ``self.char_list``; anything that cannot be decoded (including the
        padding value) becomes ``'_'``. Note that negative values index from
        the end of the alphabet, following normal Python indexing.

        Args:
            dig_lst: Sliceable sequence of int-convertible values.
            max_len: Number of leading entries to decode.

        Returns:
            str: One character per decoded entry.
        """
        decoded = []
        for dig in dig_lst[:max_len]:
            try:
                decoded.append(self.char_list[int(dig)])
            except Exception:
                # Best-effort decode; `Exception` rather than a bare except so
                # KeyboardInterrupt / SystemExit still propagate.
                decoded.append('_')

        return ''.join(decoded)

    def get_classes(self):
        """Return ``(classes, class_to_idx, idx_to_class)`` built from the labels.

        ``classes`` is the ``label`` column in file order. If a label occurs
        more than once, ``class_to_idx`` keeps the index of its last occurrence.
        """
        classes = list(self.class_df["label"])
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        idx_to_class = dict(enumerate(classes))
        return classes, class_to_idx, idx_to_class
    
class Cangjie_Dataset(Dataset):
    """Map-style dataset of single-character PNG images with Cangjie labels.

    Expects ``targ_dir`` to contain ``952_labels.txt`` and a split directory
    ``952_<set>`` whose sub-directories hold the ``*.png`` files. The name of
    each sub-directory is resolved to a label string through
    ``Cangjie_Class.get_class_name_from_path``.
    """

    def __init__(self, targ_dir: str, set:str, transform=None) -> None:
        """Index the image files of one split.

        Args:
            targ_dir: Dataset root directory.
            set: Split name; images are read from ``<targ_dir>/952_<set>``.
            transform: Optional callable applied to each loaded PIL image in
                ``__getitem__``.
        """
        root = pathlib.Path(targ_dir)
        self.cangjie = Cangjie_Class(root / "952_labels.txt")
        # glob() yields files in filesystem order; sorting makes the
        # index -> sample mapping reproducible across machines and runs.
        self.paths = sorted((root / f"952_{set}").glob("*/*.png"))
        self.transform = transform
        self.classes, self.class_to_idx, self.idx_to_class = self.cangjie.get_classes()

    def load_image(self, index: int) -> Image.Image:
        """Open the image at ``index`` and return it fully loaded."""
        image = Image.open(self.paths[index])
        # Image.open is lazy and keeps the file open until the pixel data is
        # read; load() forces the read so handles do not pile up when the
        # dataset is iterated over many files.
        image.load()
        return image

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[object, torch.Tensor, torch.Tensor]:
        """Return one sample as ``(image, target, target_length)``.

        Returns:
            image: The PIL image, or the result of ``transform(image)`` when a
                transform was supplied.
            target: ``LongTensor`` of per-character labels (``CHAR2LABEL``,
                i.e. 1..26) for the sample's label string.
            target_length: One-element ``LongTensor`` holding ``len(target)``.

        Raises:
            KeyError: If the label string contains a character outside
                ``Cangjie_Class.CHARS``.
        """
        image = self.load_image(index)
        if self.transform is not None:
            image = self.transform(image)

        text = self.cangjie.get_class_name_from_path(self.paths[index])
        target = torch.LongTensor([self.cangjie.CHAR2LABEL[c] for c in text])
        target_length = torch.LongTensor([len(text)])
        return image, target, target_length

In [3]:
# Test split: the in-notebook Cangjie_Dataset reads 952_labels.txt and
# 952_test/*/*.png under the given dataset root.
test_set = Cangjie_Dataset(targ_dir="etl_952_singlechar_size_64", set="test")

In [1]:
# Binds `Cangjie_Dataset` to the class from the local `cangjie_dataset` module.
# If run in the same kernel as the in-notebook class definition, this shadows it.
from cangjie_dataset import Cangjie_Dataset

# Validation split.
# NOTE(review): assumes the module's constructor takes (dataset root, split name)
# like the in-notebook class does — confirm against cangjie_dataset.py.
val_dataset = Cangjie_Dataset("etl_952_singlechar_size_64", "val")

In [3]:
# Unpack the first validation sample into three values.
# NOTE(review): the in-notebook Cangjie_Dataset.__getitem__ returns
# (image, target, target_length), so `cls` / `labels` may be misnamed here.
# `val_dataset` comes from the imported module, whose return order is not
# visible in this notebook — confirm.
image, cls, labels = val_dataset[0]

torch.Size([5]) tensor([0.]) zc___


In [8]:
import torch

# 27 consecutive values 0..26 in the default floating dtype
# (presumably one per letter plus one extra index — confirm).
output = torch.arange(27, dtype=torch.get_default_dtype())

print(output)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26.])


In [None]:
from torch import nn

In [6]:
# Adds the first two entries of `labels`, so `labels` needs at least two elements.
# NOTE(review): `labels` is presumably the third value unpacked from
# `val_dataset[0]`; the in-notebook __getitem__ puts a one-element
# target_length tensor in that position — confirm which dataset version
# produced this variable.
print(labels[0] + labels[1])

tensor(29)


In [5]:
# Right-pad the label with underscores up to a width of 5 characters.
label = '1'.ljust(5, '_')
print(label)

1____
