In [19]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import transforms
import os

In [20]:
def font_to_label(font_name):
    """Translate a font code into its integer class label.

    Args:
        font_name: Font code to look up (e.g. ``'cs'``).

    Returns:
        int: ``0`` for ``'cs'``, ``1`` for ``'ks'``, ``2`` for ``'ls'``,
        ``3`` for ``'xs'``, ``4`` for ``'zs'``; ``-1`` for any other value.
    """
    # Position in this tuple is the label; append new font codes at the end
    # so that existing labels keep their values.
    known_codes = ('cs', 'ks', 'ls', 'xs', 'zs')
    label_by_code = {code: position for position, code in enumerate(known_codes)}
    try:
        return label_by_code[font_name]
    except KeyError:
        # Unknown fonts are reported with the sentinel -1.
        return -1

In [None]:
class CalligraphyDataset(Dataset):
    """Map-style dataset of character images stored in per-character folders.

    Every immediate sub-directory of ``data_dir`` is treated as one character
    class; image files are collected recursively beneath it. The first two
    characters of each image file name are passed to ``font_to_label`` to
    obtain a font label for that image.

    Attributes:
        data_dir: Root directory given by the caller.
        transform: Optional callable applied to the RGB image in
            ``__getitem__``.
        image_paths: Path of every discovered image file.
        labels: Character class index for each entry of ``image_paths``.
        labels_font: Font label for each entry of ``image_paths``
            (0 cs; 1 ks; 2 ls; 3 xs; 4 zs).
        char_to_idx: Maps a character directory name to its class index.
        idx_to_char: Inverse of ``char_to_idx``.
    """

    def __init__(self, data_dir, transform=None):
        """Scan ``data_dir`` and index every supported image file.

        Args:
            data_dir: Directory whose sub-directories are character classes.
            transform: Optional callable applied to each loaded image.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.char_to_idx = {}
        self.idx_to_char = {}
        # 0 cs; 1 ks; 2 ls; 3 xs; 4 zs
        self.labels_font = []

        supported_exts = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
        # Sorting the class directories keeps class indices stable across runs.
        for char_name in sorted(os.listdir(data_dir)):
            char_dir = os.path.join(data_dir, char_name)
            if not os.path.isdir(char_dir):
                continue  # plain files directly under data_dir are ignored

            char_idx = self.char_to_idx.setdefault(char_name, len(self.char_to_idx))
            self.idx_to_char[char_idx] = char_name

            for root, dirs, files in os.walk(char_dir):
                # os.walk yields entries in arbitrary filesystem order; sorting
                # both the traversal and the file names makes sample indices
                # reproducible (e.g. for seeded splits).
                dirs.sort()
                for file in sorted(files):
                    if file.lower().endswith(supported_exts):
                        self.image_paths.append(os.path.join(root, file))
                        self.labels_font.append(font_to_label(file[0:2]))
                        self.labels.append(char_idx)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the sample at ``idx``, falling back to later samples on error.

        If the image cannot be opened, a warning is printed and the next
        sample (wrapping around to index 0) is tried instead; the returned
        label then belongs to the sample that was actually loaded.

        Args:
            idx: Sample index.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB-converted
            image (after ``transform`` if one was given) and ``label`` is the
            character class index.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no image in the dataset could be opened.
        """
        total = len(self)
        current = idx
        failures = 0
        # Bounded loop instead of recursion: a long run of unreadable files
        # can no longer overflow the call stack.
        while True:
            img_path = self.image_paths[current]
            try:
                # NOTE(review): `Image` is not imported in the visible import
                # cell; presumably PIL.Image is imported elsewhere - confirm.
                # The context manager closes the underlying file promptly;
                # convert() returns an independent, fully loaded image.
                with Image.open(img_path) as raw_image:
                    image = raw_image.convert('RGB')
            except Exception as e:
                print(f"Warning: Skipping corrupted image {img_path}: {e}")
                failures += 1
                if failures >= total:
                    raise RuntimeError(
                        f"No readable image found in dataset ({total} files tried)"
                    ) from e
                current = (current + 1) % total
                continue

            label = self.labels[current]
            if self.transform:
                image = self.transform(image)
            return image, label


In [22]:
# Preprocessing pipelines keyed by split name. Both resize to 224x224, convert
# to a tensor and normalise with the per-channel mean/std below (the values
# conventionally used for ImageNet-trained torchvision models). Only the
# 'train' pipeline applies random augmentation.
data_transforms = dict(
    train=transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(
            degrees=0,
            translate=(0.2, 0.2),
            scale=(0.7, 1.3),
        ),
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
        ),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]),
    val=transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]),
)

# NOTE(review): the whole dataset is built with the *training* transforms; if
# it is later divided into train/validation subsets, the validation samples
# would also be augmented - confirm this is intended.
full_dataset = CalligraphyDataset(
    'chinese_fonts',
    transform=data_transforms['train'],
)

In [None]:
# Convert any remaining font codes ('cs', 'ks', ...) to integer labels.
# CalligraphyDataset.__init__ already stores the result of font_to_label, and
# font_to_label returns -1 for anything that is not a known code, so mapping
# integer labels a second time would overwrite every entry with -1. Only
# string entries are converted, which makes this cell safe to re-run.
full_dataset.labels_font = [
    font_to_label(font) if isinstance(font, str) else font
    for font in full_dataset.labels_font
]

['cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ls',
 'ls',
 'ls',
 'ls',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ks',
 'ls',
 'ls',
 'ls',
 'ls',
 'ls',
 'ls',
 'ls',
 'ls',
 'ls',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'xs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'zs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',
 'cs',