In [1]:
# modules import
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import string
import torch.optim as optim
from tqdm import tqdm
import torchmetrics
from torchinfo import summary

In [2]:
#symbols
dictionary = string.ascii_lowercase + string.digits

In [3]:
#image processing
class Captcha(Dataset):
    def __init__(self, dataset_path):
        self.images = os.listdir(dataset_path)
        self.directory = dataset_path
        self.transform = transforms.Compose([transforms.PILToTensor()])
    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]

        X = Image.open(self.directory + "/" + image).convert('RGB')
        y = image.split(".")[0]
        matrix = torch.zeros(len(y), len(dictionary))
        for i in range(len(y)):
            j = dictionary.find(y[i])
            matrix[i][j] = 1 
        return self.transform(X).float().to("cuda:0"), matrix.to("cuda:0")

In [4]:
dataset_path = "samples"

In [5]:
from torch.utils.data import random_split
def load_and_split_data(dataset_path: str):
    captcha_dataset = Captcha(dataset_path)
    number_of_images = len(captcha_dataset.images)
    train_size = int(round(number_of_images * 0.8))
    print(train_size)
    train_dataset, test_dataset = random_split(captcha_dataset, [train_size, number_of_images - train_size])
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
    test_loader = DataLoader(test_dataset, batch_size=32)
    return train_loader, test_loader

In [6]:
train_loader, test_loader = load_and_split_data(dataset_path)

856


In [7]:
#model
class CaptchaModel(torch.nn.Module):
    def __init__(self):
        super(CaptchaModel, self).__init__()
        self.cnn1 = nn.Conv2d(3, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.cnn2 = nn.Conv2d(64, 128, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(128)
        self.cnn3 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool2 = nn.MaxPool2d((2,3))
        self.cnn4 = nn.Conv2d(256, 512, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(1536, 256)
        self.lstm = nn.LSTM(256, 128, 2, bidirectional=True, batch_first=True)
        self.fc2 = nn.Linear(256, 36)

    def forward(self, x):
        x = self.cnn1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.cnn2(x)
        x = self.batchnorm1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.cnn3(x) 
        x = F.relu(x)
        x = self.pool2(x)
        x = self.cnn4(x)
        x = self.batchnorm2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = x.permute(0, 3, 1, 2)
        x = x.view(x.size(0), x.size(1), -1)
        x = self.fc1(x)
        x, _ = self.lstm(x)
        x = self.fc2(x)
        return x.reshape(-1, 5, 36)

In [8]:
model = CaptchaModel()
model.to("cuda:0")

CaptchaModel(
  (cnn1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (cnn3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=(2, 3), stride=(2, 3), padding=0, dilation=1, ceil_mode=False)
  (cnn4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=1536, out_features=256, bias=True)
  (lstm): LSTM(256, 128, num_layers=2, batch_first=True, bidirectional=True)
  (fc2): Linear(in_features=256, out_features=36, bias=True)
)

In [9]:
summary(model)

Layer (type:depth-idx)                   Param #
CaptchaModel                             --
├─Conv2d: 1-1                            1,792
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            73,856
├─BatchNorm2d: 1-4                       256
├─Conv2d: 1-5                            295,168
├─MaxPool2d: 1-6                         --
├─Conv2d: 1-7                            1,180,160
├─BatchNorm2d: 1-8                       1,024
├─Linear: 1-9                            393,472
├─LSTM: 1-10                             790,528
├─Linear: 1-11                           9,252
Total params: 2,745,508
Trainable params: 2,745,508
Non-trainable params: 0

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())#SGD(model.parameters(), lr=0.05, momentum=0.6)

In [11]:
def train_model(model, criterion, optimizer, trainloader, num_epochs=40):
    model.train(True)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(tqdm(trainloader)):
            inputs, labels = data
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            running_loss += loss.item()
        print('Epoch {0}/{1}, iteration {2}, loss: {3:.3f}'.format(epoch + 1, num_epochs, i + 1, 
                                                                  running_loss / i))
        print()

    print('Finished Training')
    
    return model

In [12]:
model = train_model(model, criterion, optimizer, train_loader, num_epochs=50)

100%|██████████| 27/27 [00:00<00:00, 33.49it/s]


Epoch 1/50, iteration 27, loss: 0.220



100%|██████████| 27/27 [00:00<00:00, 37.40it/s]


Epoch 2/50, iteration 27, loss: 0.157



100%|██████████| 27/27 [00:00<00:00, 37.44it/s]


Epoch 3/50, iteration 27, loss: 0.122



100%|██████████| 27/27 [00:00<00:00, 37.45it/s]


Epoch 4/50, iteration 27, loss: 0.094



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 5/50, iteration 27, loss: 0.074



100%|██████████| 27/27 [00:00<00:00, 37.45it/s]


Epoch 6/50, iteration 27, loss: 0.059



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 7/50, iteration 27, loss: 0.051



100%|██████████| 27/27 [00:00<00:00, 37.47it/s]


Epoch 8/50, iteration 27, loss: 0.040



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 9/50, iteration 27, loss: 0.034



100%|██████████| 27/27 [00:00<00:00, 37.43it/s]


Epoch 10/50, iteration 27, loss: 0.031



100%|██████████| 27/27 [00:00<00:00, 37.42it/s]


Epoch 11/50, iteration 27, loss: 0.029



100%|██████████| 27/27 [00:00<00:00, 37.40it/s]


Epoch 12/50, iteration 27, loss: 0.029



100%|██████████| 27/27 [00:00<00:00, 37.42it/s]


Epoch 13/50, iteration 27, loss: 0.027



100%|██████████| 27/27 [00:00<00:00, 37.40it/s]


Epoch 14/50, iteration 27, loss: 0.025



100%|██████████| 27/27 [00:00<00:00, 37.42it/s]


Epoch 15/50, iteration 27, loss: 0.025



100%|██████████| 27/27 [00:00<00:00, 37.40it/s]


Epoch 16/50, iteration 27, loss: 0.024



100%|██████████| 27/27 [00:00<00:00, 37.42it/s]


Epoch 17/50, iteration 27, loss: 0.024



100%|██████████| 27/27 [00:00<00:00, 37.41it/s]


Epoch 18/50, iteration 27, loss: 0.024



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 19/50, iteration 27, loss: 0.024



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 20/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.42it/s]


Epoch 21/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.41it/s]


Epoch 22/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.43it/s]


Epoch 23/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 24/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.34it/s]


Epoch 25/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.37it/s]


Epoch 26/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.35it/s]


Epoch 27/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 28/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 29/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.38it/s]


Epoch 30/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 31/50, iteration 27, loss: 0.025



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 32/50, iteration 27, loss: 0.025



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 33/50, iteration 27, loss: 0.024



100%|██████████| 27/27 [00:00<00:00, 37.37it/s]


Epoch 34/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 35/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.38it/s]


Epoch 36/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 37/50, iteration 27, loss: 0.023



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 38/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 39/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.34it/s]


Epoch 40/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.36it/s]


Epoch 41/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 42/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.34it/s]


Epoch 43/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 44/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.37it/s]


Epoch 45/50, iteration 27, loss: 0.022



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 46/50, iteration 27, loss: 0.021



100%|██████████| 27/27 [00:00<00:00, 37.34it/s]


Epoch 47/50, iteration 27, loss: 0.021



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]


Epoch 48/50, iteration 27, loss: 0.021



100%|██████████| 27/27 [00:00<00:00, 37.37it/s]


Epoch 49/50, iteration 27, loss: 0.021



100%|██████████| 27/27 [00:00<00:00, 37.39it/s]

Epoch 50/50, iteration 27, loss: 0.021

Finished Training





In [13]:
def from_idx_to_str(idxs, dictionary):
    answers = []
    for els in idxs:
        answer = ""
        for el in els.cpu().numpy():
            answer += dictionary[el]
        answers.append(answer)
    return answers

In [14]:
import numpy as np

In [15]:
def model_cer(model, testloader):
    errors = []
    num = 0
    model.train(False)
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            
            output = model.forward(images)
            pred = torch.argmax(output, dim = 2)
            target = torch.argmax(labels, dim = 2)
            preds = from_idx_to_str(pred.squeeze(), dictionary)
            targets = from_idx_to_str(target.squeeze(), dictionary)
            num += len(pred)
            errors.append(torchmetrics.functional.char_error_rate(preds, targets))
    print('CER of the network on the {} images: {} %'.format(num, 100 * np.mean(errors)))

In [16]:
model_cer(model, train_loader)

CER of the network on the 856 images: 2.0293209701776505 %


In [17]:
model_cer(model, test_loader)

CER of the network on the 214 images: 5.113636702299118 %


##  Анализ ошибок

In [18]:
from collections import Counter

In [19]:
errors = []
for data in test_loader:
    images, labels = data

    output = model.forward(images)
    preds = torch.argmax(output, dim = 2)
    targets = torch.argmax(labels, dim = 2)
    pred = from_idx_to_str(preds.squeeze(), dictionary)
    target = from_idx_to_str(targets.squeeze(), dictionary)
    for i in range(len(pred)):
        for j in range(len(pred[i])):
            if pred[i][j] != target[i][j]:
                errors.append(pred[i][j] + target[i][j])
Counter(errors)

Counter({'px': 1,
         'mn': 2,
         'pb': 2,
         'bd': 1,
         'yw': 2,
         'df': 2,
         '83': 9,
         'ew': 1,
         'bn': 3,
         'f4': 1,
         '5c': 1,
         'nm': 7,
         '58': 1,
         'ce': 1,
         '72': 1,
         'yx': 1,
         'bp': 2,
         'pw': 1,
         'cx': 1,
         '84': 1,
         '2x': 1,
         '48': 2,
         '53': 1,
         '27': 1,
         '5f': 1,
         'd3': 1,
         'db': 1,
         '23': 1,
         '56': 1,
         'wx': 1,
         'xc': 1,
         '4f': 1,
         '85': 1})

Видно, что чаще всего ошибки происходят в похожих символах (mn, 83). При обучении можно попробовать сэмплировать данные с весами(брать чаще изображения с данными спорными символами, чтобы сеть больше на них училась).