In [102]:
# pylint: disable=unused-import
import os
import random

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split




In [15]:
RANDOM_STATE = 42

In [63]:
images = os.listdir('./dataset')
train_images, test_images = train_test_split(images, shuffle=True, train_size=0.75, random_state=RANDOM_STATE)
print('train size:', len(train_images))
print('test size:', len(test_images))


chars = list(map(lambda x: x[:-4], images))
chars = ''.join(chars)
chars = set(chars)
NUM_CHARS = len(chars)
print('unique chars:', len(chars))

idx2char = {k: v for k, v in enumerate(chars, start=0)}
char2idx = {k: v for v, k in idx2char.items()}

train size: 802
test size: 268
unique chars: 19


### Dataset and DataLoader

In [97]:
class CaptchaDataset(Dataset):
    def __init__(self, images, img_dir='./dataset/', transform=None):
        self.images = images
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('L')
        label = self.images[idx][:-4]
        label = torch.Tensor([char2idx[char] for char in label])
        if self.transform:
            image = self.transform(image)
        return image, label

In [98]:
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.25)
        ])

train_dataset = CaptchaDataset(train_images, transform=transform)
test_dataset = CaptchaDataset(test_images, transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [99]:
x, y =next(iter(train_dataloader))
x.shape, y.shape


(torch.Size([64, 1, 50, 200]), torch.Size([64, 5]))

### Model

In [96]:
class OCRModel(nn.Module):
    def __init__(self):
        super(OCRModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 3))

        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 3))
        
        self.linear1 = nn.Linear(1536, 64)
        self.lstm = nn.LSTM(64, 32, bidirectional=True, num_layers=2, dropout=0.2, batch_first=True)
        self.linear2 = nn.Linear(64, NUM_CHARS)


    def forward(self, x):
        #  [batch, 1, 50, 200]
        x = self.pool1(F.relu(self.bn1(self.conv1(x)))) #  [batch, 64, 25, 100]
        x = F.relu(self.pool2(self.bn2(self.conv2(x)))) #  [batch, 128, 12, 50]
        x = F.relu(self.pool3(self.bn3(self.conv3(x)))) #  [batch, 256, 6, 16]
        x = self.pool4(self.bn4(self.conv4(x))) #  [batch, 512, 3, 5]
        x = x.permute(0, 3, 1, 2) #  [batch, 5, 512, 3]
        x = x.view(x.shape[0], x.shape[1], -1) #  [batch, 5, 1536]
        x = F.relu(self.linear1(x)) #  [batch, 5, 64]
        x, _ = self.lstm(x) #  [batch, 5, 64]
        x = self.linear2(x) #  [batch, 5, 19]
        return x


if __name__ == "__main__":
    model = OCRModel()
    model.eval()
    x, y = next(iter(train_dataloader))
    preds = model(x)
    print(preds.shape)
    

torch.Size([64, 5, 19])


In [118]:
model = OCRModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

def train_model(model, criterion, optimizer, trainloader, num_epochs=5):
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            preds = model(inputs)
            preds = preds.permute(0, 2, 1)
            loss = criterion(preds, labels.long())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:
                print('Epoch {0}/{1}, iteration {2}, loss: {3:.3f}'.format(epoch + 1, num_epochs, i + 1, 
                                                                          running_loss / 2000))
                running_loss = 0.0
        print()

    print('Finished Training')


train_model(model, criterion, optimizer, train_dataloader, num_epochs=5)


torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])
torch.Size([64, 19, 5]) torch.Size([64, 5])


KeyboardInterrupt: 

RuntimeError: shape '[0, 2, 1]' is invalid for input of size 6080

In [30]:
model.eval()
preds = model(x)
preds.shape

torch.Size([64, 512, 3, 12])

In [41]:
class CaptchaModel(nn.Module):
    def __init__(self, num_chars):
        super(CaptchaModel, self).__init__()
        self.conv_1 = nn.Conv2d(3, 128, kernel_size=(3, 6), padding=(1, 1))
        self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv_2 = nn.Conv2d(128, 64, kernel_size=(3, 6), padding=(1, 1))
        self.pool_2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.linear_1 = nn.Linear(1152, 64)
        self.drop_1 = nn.Dropout(0.2)
        self.lstm = nn.GRU(64, 32, bidirectional=True, num_layers=2, dropout=0.25, batch_first=True)
        self.output = nn.Linear(64, num_chars + 1)

    def forward(self, images, targets=None):
        bs, _, _, _ = images.size()
        x = F.relu(self.conv_1(images))
        x = self.pool_1(x)
        x = F.relu(self.conv_2(x))
        x = self.pool_2(x)
        print(x.shape)
        x = x.permute(0, 3, 1, 2)
        print(x.shape)
        x = x.view(bs, x.size(1), -1)
        print(x.shape)
        x = F.relu(self.linear_1(x))
        print(x.shape)
        x = self.drop_1(x)
        x, _ = self.lstm(x)
        x = self.output(x)
        x = x.permute(1, 0, 2)

        if targets is not None:
            log_probs = F.log_softmax(x, 2)
            input_lengths = torch.full(
                size=(bs,), fill_value=log_probs.size(0), dtype=torch.int32
            )
            target_lengths = torch.full(
                size=(bs,), fill_value=targets.size(1), dtype=torch.int32
            )
            loss = nn.CTCLoss(blank=0)(
                log_probs, targets, input_lengths, target_lengths
            )
            return x, loss

        return x, None


if __name__ == "__main__":
    cm = CaptchaModel(19)
    img = torch.rand((1, 3, 75, 300))
    x, _ = cm(img, torch.rand((1, 5)))

torch.Size([1, 64, 18, 72])
torch.Size([1, 72, 64, 18])
torch.Size([1, 72, 1152])
torch.Size([1, 72, 64])
