In [6]:
import os
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split




In [7]:
RANDOM_STATE = 42

In [62]:
images = os.listdir('./laba-dataset')
train_images, test_images = train_test_split(images, shuffle=True, train_size=0.75, random_state=RANDOM_STATE)
print('train size:', len(train_images))
print('test size:', len(test_images))


chars = list(map(lambda x: x[:-4], images))
chars = ''.join(chars)
chars = set(chars)
print('unique chars:', len(chars))

idx2char = {k: v for k, v in enumerate(chars, start=0)}
char2idx = {k: v for v, k in idx2char.items()}

train size: 802
test size: 268
unique chars: 19


### Dataset and DataLoader

In [50]:
class CaptchaDataset(Dataset):
    def __init__(self, images, img_dir='./laba-dataset/', transform=None):
        self.images = images
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('L')
        label = self.images[idx][:-4]
        if self.transform:
            image = self.transform(image)
        return image, label

In [51]:
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.25)
        ])

train_dataset = CaptchaDataset(train_images, transform=transform)
test_dataset = CaptchaDataset(test_images, transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

train size: 802
test size: 268


In [63]:
x, y =next(iter(train_dataloader))
x, y

(tensor([[[[1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           [1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           [1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           ...,
           [1.0588, 1.0588, 1.0588,  ..., 1.9843, 1.9843, 1.9843],
           [1.0588, 1.0588, 1.0588,  ..., 1.9843, 1.9843, 1.9843],
           [1.0588, 1.0588, 1.0588,  ..., 1.9843, 1.9843, 1.9843]]],
 
 
         [[[1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           [1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           [1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           ...,
           [1.0588, 1.0588, 1.0588,  ..., 1.9843, 1.9843, 1.9843],
           [1.0588, 1.0588, 1.0588,  ..., 1.9843, 1.9843, 1.9843],
           [1.0588, 1.0588, 1.0588,  ..., 1.9843, 1.9843, 1.9843]]],
 
 
         [[[1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           [1.0118, 1.0118, 1.0118,  ..., 1.9373, 1.9373, 1.9373],
           [1.0118

### Model