In [1]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [3]:
class HandwritingDataset(Dataset):
    """Paired handwriting images and stroke arrays stored under one root folder.

    Expected layout::

        root_dir/
            Images/   <name>.png
            Strokes/  <name>.npy

    Each item is ``(image, stroke_data)``. The image is opened with PIL,
    converted to mode ``'L'`` and passed through ``transform`` when one is
    given. ``stroke_data`` is whatever ``np.load`` returns for the ``.npy``
    file that shares the image's base name.

    Args:
        root_dir: Directory containing the ``Images`` and ``Strokes`` folders.
        transform: Optional callable applied to the PIL image before it is
            returned.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'Images')
        self.stroke_dir = os.path.join(root_dir, 'Strokes')
        self.transform = transform
        # os.listdir yields entries in arbitrary order; sorting keeps the
        # idx -> sample mapping stable across runs and machines, which any
        # index-based split of this dataset relies on.
        self.image_files = sorted(
            f for f in os.listdir(self.image_dir) if f.endswith('.png')
        )

    def __len__(self):
        """Return the number of ``.png`` files found in the image folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image / stroke pair at position ``idx``.

        Returns:
            Tuple ``(image, stroke_data)`` as described on the class.

        Raises:
            FileNotFoundError: If the matching ``.npy`` file does not exist.
        """
        file_name = self.image_files[idx]
        img_name = os.path.join(self.image_dir, file_name)
        # Swap only the trailing extension; str.replace would also rewrite any
        # earlier '.png' occurring inside the base name.
        base_name = os.path.splitext(file_name)[0]
        stroke_name = os.path.join(self.stroke_dir, base_name + '.npy')

        # convert() returns a new, fully loaded image, so the underlying file
        # handle can be released as soon as the with-block exits.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('L')
        stroke_data = np.load(stroke_name)

        if self.transform:
            image = self.transform(image)

        return image, stroke_data

# Image preprocessing pipeline handed to the dataset. At the moment its only
# step is ToTensor(); any resizing or normalization steps would be added to
# this list.
transform = transforms.Compose([transforms.ToTensor()])

In [4]:
# Create the dataset.
# data_path must point at the Train directory, i.e. the folder whose direct
# subfolders are Images/ and Strokes/.
data_path = '../../../DataSet/IAM-Online/Resized_Dataset/Train/'
dataset = HandwritingDataset(root_dir=data_path, transform=transform)

# 80/20 train/test sizes; the test part takes the remainder so the two sizes
# always add up to len(dataset).
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Draw the split from a dedicated, seeded generator so the same dataset indices
# land in train/test on every kernel restart. Without it the partition changes
# from run to run and test samples can leak into training across sessions.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Create dataloaders: training batches are reshuffled each epoch, test batches
# keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Check the number of batches in the train and test loaders
print(len(train_loader), len(test_loader))

1096 274


In [5]:
# Prints only the DataLoader object's default repr (class path and memory
# address); no batch is fetched or displayed by this call.
print(train_loader)

<torch.utils.data.dataloader.DataLoader object at 0x000002476A7CDA30>


In [6]:
# Prints only the DataLoader object's default repr (class path and memory
# address); no batch is fetched or displayed by this call.
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x000002476A7CD7F0>
