In [None]:
from typing import Tuple
from torch.utils.data import Dataset
import pathlib
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pandas as pd 

class chinese_class():
    """Label table for the ETL952 character set.

    Reads a whitespace-separated text file whose first line is a header (skipped)
    and whose remaining non-blank lines each hold exactly five fields:
    ``id char hex uni label``. The rows are stored, in file order, in
    ``self.class_df``: a pandas DataFrame with those five columns, ``id`` parsed
    as ``int`` and the other four kept as strings.
    """

    def __init__(self, file_path, encoding="utf-8"):
        """Load the label table from ``file_path``.

        Args:
            file_path: path (str or ``pathlib.Path``) of the label file.
            encoding: text encoding of the file. The ``char`` column holds CJK
                characters, so the encoding is explicit instead of relying on the
                platform's locale default (typically not UTF-8 on Windows).

        Raises:
            ValueError: if a non-blank data line does not have exactly five
                fields, or its ``id`` field is not an integer.
        """
        columns = ("id", "char", "hex", "uni", "label")
        rows = []
        with open(file_path, "r", encoding=encoding) as f:
            f.readline()  # skip the header line
            for line_no, line in enumerate(f, start=2):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing empty line
                if len(fields) != len(columns):
                    raise ValueError(
                        f"{file_path}:{line_no}: expected {len(columns)} fields "
                        f"({' '.join(columns)}), got {len(fields)}"
                    )
                row = dict(zip(columns, fields))
                row["id"] = int(row["id"])
                rows.append(row)

        # Passing `columns` fixes the column order and keeps the columns even
        # when the file has no data rows.
        self.class_df = pd.DataFrame(rows, columns=list(columns))

    def get_class_name_from_path(self, image_path):
        """Return the ``label`` for an image stored under an integer-named directory.

        The parent directory name (``.../17/img.png`` -> ``17``) is parsed as an int
        and used as a *positional* row index (``iloc``) into ``class_df``, not as a
        lookup on the ``id`` column. NOTE(review): the two agree only if ids are
        0..N-1 in file order -- confirm against the dataset layout.

        Args:
            image_path: ``str`` or ``pathlib.Path`` of the image file.
        """
        row_position = int(pathlib.Path(image_path).parent.stem)
        return self.class_df["label"].iloc[row_position]

    def get_classes(self):
        """Return ``(classes, class_to_idx)``.

        ``classes`` is the ``label`` column as a list in file order (one entry per
        row, duplicates included). ``class_to_idx`` maps each label to its position
        in that list; if a label repeats, the position of its last occurrence wins.
        """
        classes = self.class_df["label"].tolist()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx
    
class Chinese_Dataset(Dataset):
    """Image-classification dataset over per-class directories of PNG files.

    Layout expected under ``targ_dir``:

    * ``952_labels.txt`` -- label table, parsed by ``chinese_class``;
    * ``952_{set}/<class dir>/*.png`` -- the images, where ``<class dir>`` is an
      integer that ``chinese_class.get_class_name_from_path`` maps to a label.

    ``dataset[i]`` returns ``(image, class_idx)``: the image (passed through
    ``transform`` when one is given, otherwise a PIL image) and the integer index
    of its label in ``class_to_idx``.
    """

    def __init__(self, targ_dir: str, set: str, transform=None) -> None:
        # `set` shadows the builtin, but it is part of the public signature
        # (callers may pass set="train"), so the name is kept.
        root = pathlib.Path(targ_dir)
        self.chinese = chinese_class(root / "952_labels.txt")
        # glob() yields files in filesystem order, which differs between machines;
        # sorting makes the index -> file mapping reproducible.
        self.paths = sorted((root / f"952_{set}").glob("*/*.png"))
        self.transform = transform
        self.classes, self.class_to_idx = self.chinese.get_classes()

    def load_image(self, index: int) -> Image.Image:
        "Opens the image at ``self.paths[index]`` and returns an in-memory copy of it."
        # Image.open() is lazy and keeps the file open until pixels are read. Copying
        # inside the context manager reads the pixels now and closes the file on exit,
        # so no descriptor is left open (e.g. when no transform ever loads the image).
        with Image.open(self.paths[index]) as img:
            return img.copy()

    def __len__(self) -> int:
        "Returns the total number of samples (PNG files found at construction time)."
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, data and label (X, y)."
        img = self.load_image(index)
        class_name = self.chinese.get_class_name_from_path(self.paths[index])
        class_idx = self.class_to_idx[class_name]
        if self.transform:
            img = self.transform(img)
        return img, class_idx

# Dataset root (relative to the notebook) and the square size images are resized to.
DATA_DIR = "etl_952_singlechar_size_64"
IMAGE_SIZE = 64

# No horizontal flip augmentation: a mirrored character is generally not the same
# character, so flipping would pair images with wrong labels.
train_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

train_data = Chinese_Dataset(DATA_DIR, "train", transform=train_transforms)

class_names = train_data.classes
class_dict = train_data.class_to_idx

In [None]:
# Rows per label in the label table. get_classes() builds class_to_idx with a dict
# comprehension, so a label counted more than once collapses to a single index.
label_column = train_data.chinese.class_df["label"]
label_column.value_counts()

In [None]:
# Number of training samples (PNG paths globbed by Chinese_Dataset).
n_train_images = len(train_data)
n_train_images

In [None]:
from typing import List
import matplotlib.pyplot as plt
import random
# 1. Take in a Dataset as well as a list of class names
def display_random_images(dataset: torch.utils.data.dataset.Dataset,
                          classes: List[str] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: int = None):
    """Plot ``n`` randomly chosen samples from ``dataset`` side by side.

    Args:
        dataset: supports ``len()``; ``dataset[i]`` returns a sequence whose first
            two items are ``(image, label)``, ``image`` being a ``(C, H, W)`` tensor
            and ``label`` an int index.
        classes: optional list mapping label index -> class name. When given, each
            panel is titled ``class: <name>`` (plus the plotted shape if
            ``display_shape``). When omitted, panels are drawn without a title.
        n: number of samples. Capped at 10 (which also turns ``display_shape``
            off) and at ``len(dataset)``.
        display_shape: append the plotted ``(H, W, C)`` shape to each title.
        seed: if not None (``0`` included), seeds the global ``random`` module
            before sampling so the same samples are picked on every call.
    """
    if n > 10:
        n = 10
        display_shape = False
        print(f"For display purposes, n shouldn't be larger than 10, setting to 10 and removing shape display.")
    # random.sample raises ValueError when k exceeds the population; show all instead.
    n = min(n, len(dataset))

    if seed is not None:
        random.seed(seed)

    random_samples_idx = random.sample(range(len(dataset)), k=n)

    plt.figure(figsize=(16, 8))

    for i, targ_sample in enumerate(random_samples_idx):
        # Fetch each sample once: indexing twice would load/transform the image twice.
        sample = dataset[targ_sample]
        targ_image, targ_label = sample[0], sample[1]

        # [color_channels, height, width] -> [height, width, color_channels] for matplotlib.
        targ_image_adjust = targ_image.permute(1, 2, 0)

        plt.subplot(1, n, i + 1)
        if targ_image_adjust.shape[-1] == 1:
            # imshow rejects (H, W, 1) arrays; drop the channel axis and draw in grayscale.
            plt.imshow(targ_image_adjust.squeeze(-1), cmap="gray")
        else:
            plt.imshow(targ_image_adjust)
        plt.axis("off")
        if classes:
            title = f"class: {classes[targ_label]}"
            if display_shape:
                title = title + f"\nshape: {targ_image_adjust.shape}"
            plt.title(title)


# Preview five random training samples, titled with their class names.
display_random_images(
    train_data,
    classes=class_names,
    n=5,
)

In [None]:
import numpy 

def compute_mean_std(dataset, channels=(0, 1, 2)):
    """Compute the per-channel pixel mean and standard deviation of an image dataset.

    Args:
        dataset: supports ``len()``; ``dataset[i][0]`` must be a ``(C, H, W)``
            array-like image (e.g. the tensor produced by ``transforms.ToTensor()``).
        channels: channel indices to summarise; defaults to the first three.
            Use ``(0,)`` for single-channel images.

    Returns:
        ``(mean, std)``: two tuples of floats, one entry per requested channel.
        ``std`` is the population standard deviation (ddof=0) over every pixel of
        every image.

    Raises:
        ValueError: if ``channels`` is empty or ``dataset`` contains no pixels.
        IndexError: if an image is not 3-D or has fewer channels than requested.
    """
    channel_idx = list(channels)
    if not channel_idx:
        raise ValueError("compute_mean_std() needs at least one channel index")

    totals = numpy.zeros(len(channel_idx))
    sq_totals = numpy.zeros(len(channel_idx))
    pixel_count = 0
    # Stream over the dataset with float64 running sums instead of stacking every
    # image in memory, and fetch each sample once (not once per channel).
    for i in range(len(dataset)):
        image = numpy.asarray(dataset[i][0], dtype=numpy.float64)
        selected = image[channel_idx, :, :].reshape(len(channel_idx), -1)
        totals += selected.sum(axis=1)
        sq_totals += numpy.square(selected).sum(axis=1)
        pixel_count += selected.shape[1]

    if pixel_count == 0:
        raise ValueError("compute_mean_std() needs a non-empty dataset")

    mean = totals / pixel_count
    # Var = E[x^2] - E[x]^2; clamp tiny negative round-off before the square root.
    std = numpy.sqrt(numpy.maximum(sq_totals / pixel_count - mean ** 2, 0.0))
    return tuple(float(m) for m in mean), tuple(float(s) for s in std)

# Per-channel mean/std of the training images (this cell's definition reads channels 0-2).
compute_mean_std(dataset=train_data)

In [None]:
import numpy 

def compute_mean_std(dataset, channel=0):
    """Compute the pixel mean and standard deviation of one channel of an image dataset.

    NOTE: this re-defines the three-channel ``compute_mean_std`` from the earlier
    cell with a different return type (two floats instead of two tuples); once
    this cell runs, only this single-channel version is in scope.

    Args:
        dataset: supports ``len()``; ``dataset[i][0]`` must be a ``(C, H, W)``
            array-like image (e.g. the tensor produced by ``transforms.ToTensor()``).
        channel: index of the channel to summarise (default: the first one).

    Returns:
        ``(mean, std)`` as floats; ``std`` is the population standard deviation
        (ddof=0) over every pixel of that channel in every image.

    Raises:
        ValueError: if ``dataset`` contains no pixels.
        IndexError: if an image is not 3-D or lacks the requested channel.
    """
    total = 0.0
    sq_total = 0.0
    pixel_count = 0
    # Stream with float64 running sums instead of stacking the whole dataset in memory.
    for i in range(len(dataset)):
        plane = numpy.asarray(dataset[i][0], dtype=numpy.float64)[channel, :, :]
        total += plane.sum()
        sq_total += numpy.square(plane).sum()
        pixel_count += plane.size

    if pixel_count == 0:
        raise ValueError("compute_mean_std() needs a non-empty dataset")

    mean = total / pixel_count
    # Var = E[x^2] - E[x]^2; clamp tiny negative round-off before the square root.
    std = numpy.sqrt(max(sq_total / pixel_count - mean * mean, 0.0))
    return float(mean), float(std)

# Mean/std of channel 0 only -- the compute_mean_std defined in this cell shadows
# the three-channel version from the previous cell.
compute_mean_std(dataset=train_data)

In [None]:
import editdistance
# Scratch check of the editdistance package: Levenshtein distance between the
# two words (two substitutions, n->h and n->m, so the expected value is 2).
editdistance.eval("banana", "bahama")