# CNN Model (Notebook version)
## 1) Data Pre-processing
Import the libraries

In [None]:
import os
import cv2
import random
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import itertools
from pathlib import Path
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader


# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
if device.type == "cuda":
    # Also report which GPU model was selected.
    print(torch.cuda.get_device_name(0))

In [None]:
# Parameters:
# Characters a captcha may contain: the ten digits followed by the 26 upper-case letters.
NUMBER = list("0123456789")
ALPHABET = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
TABLE = NUMBER + ALPHABET  # class index -> character lookup table for the captcha
LEN_OF_TABLE = len(TABLE)  # number of character classes (36)
LEN_OF_CAPTCHA = 6  # characters per captcha
BATCH_SIZE = 100  # training batch size
LEARNING_RATE = 0.001  # optimizer learning rate

Load the data. Please change the path below so that it points to your own copy of the data set.

In [None]:
# Data set location on Google Drive (when running in Colab):
# data_dir = Path("/content/drive/MyDrive/Data")

# Data set location on the local machine
data_dir = Path("D:/Task/FS")

# Collect every *.jpg file that sits directly inside data_dir
images = [*data_dir.glob("*.jpg")]
print("Number of images found: ", len(images))

The CAPTCHA samples in the data set

In [None]:
# Preview up to six captcha images in a 2x3 grid.
sample_images = images[:6]
_, ax = plt.subplots(2, 3, figsize=(20, 5))
# zip() stops at the shorter sequence, so fewer than six images leaves the
# remaining axes empty instead of raising an IndexError.
for axis, sample_path in zip(ax.flat, sample_images):
    img = cv2.imread(str(sample_path))
    if img is None:
        # cv2.imread reports an unreadable file by returning None, not by raising
        print("Could not read image: ", sample_path)
        continue
    print("Shape of image: ", img.shape)
    # OpenCV decodes colour images in BGR order while matplotlib expects RGB
    axis.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()

### Custom data set class

In [None]:
# Convert the captcha text into a flat vector of LEN_OF_CAPTCHA * LEN_OF_TABLE entries
# (6 characters, each one-hot encoded over 10 digits + 26 upper-case letters).
# 1 means "this position holds this TABLE character", 0 means it does not.
def captcha_to_vector(captcha_str):
    """Encode a captcha string as concatenated one-hot vectors.

    The text is upper-cased first. Character ``i`` sets the entry at
    ``i * LEN_OF_TABLE + class_index`` to 1.0, where ``class_index`` is 0-9
    for digits and 10-35 for the letters A-Z. Positions beyond the end of a
    short captcha stay all-zero.

    Args:
        captcha_str: captcha text made of digits and ASCII letters.

    Returns:
        A float32 numpy array of shape ``(LEN_OF_TABLE * LEN_OF_CAPTCHA,)``.

    Raises:
        ValueError: if the text is longer than ``LEN_OF_CAPTCHA`` or contains
            a character outside 0-9 / A-Z.
    """
    captcha_str = captcha_str.upper()
    if len(captcha_str) > LEN_OF_CAPTCHA:
        raise ValueError(
            "captcha {!r} is longer than {} characters".format(captcha_str, LEN_OF_CAPTCHA)
        )

    vector = np.zeros(LEN_OF_TABLE * LEN_OF_CAPTCHA, dtype=np.float32)
    for position, char in enumerate(captcha_str):
        if "0" <= char <= "9":  # digits map to classes 0-9
            index = ord(char) - ord("0")
        elif "A" <= char <= "Z":  # letters map to classes 10-35
            index = ord(char) - ord("A") + 10
        else:
            # Fail loudly: silently reusing a previous index would corrupt the label
            raise ValueError(
                "unsupported character {!r} in captcha {!r}".format(char, captcha_str)
            )
        vector[position * LEN_OF_TABLE + index] = 1.0
    return vector

In [None]:
# Convert a sequence of TABLE indices back into the captcha text
# (the input is a sequence of class indices, one per character, not a one-hot vector)
# example:
#   input:[1,2,34,2,6,7]
#   output:"12Y267"
def vector_to_captcha(vector):
    """Decode a sequence of class indices into the captcha string.

    Args:
        vector: iterable of integer indices into TABLE, one per character.

    Returns:
        The characters ``TABLE[i]`` for each index, concatenated in order;
        an empty string for an empty input.
    """
    return "".join(TABLE[i] for i in vector)

In [None]:
def update_matrix(preds, label, conf_matrix):
    """Add one captcha's character-level results to the confusion matrix.

    Args:
        preds: predicted captcha text (sequence of TABLE characters).
        label: ground-truth captcha text, aligned with ``preds``.
        conf_matrix: 2-D matrix indexed as ``[true class, predicted class]``;
            it is updated in place.

    Returns:
        The same ``conf_matrix`` object, after the update.
    """
    for predicted_char, true_char in zip(preds, label):
        # Row = true class, column = predicted class, which is the layout
        # plot_confusion_matrix labels ('True label' on y, 'Predicted label' on x)
        conf_matrix[TABLE.index(true_char), TABLE.index(predicted_char)] += 1
    return conf_matrix

In [None]:
# Custom dataset
class CustomDataset(torch.utils.data.Dataset):
    """Dataset of captcha images whose label is encoded in the file name.

    The label text is the part of the file name before the first underscore;
    it is converted to a vector with ``captcha_to_vector``.

    Args:
        images: sequence of path objects (``.name`` is read) pointing to the images.
        transform: optional callable applied to the loaded image array.
        target_transform: optional callable applied to the label vector.
        height: stored on the instance; not used by this class itself.
        width: stored on the instance; not used by this class itself.
    """

    def __init__(self, images, transform=None, target_transform=None, height=50, width=200):
        self.transform = transform
        self.target_transform = target_transform
        self.images = images
        self.width = width
        self.height = height

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair for the sample at ``index``.

        Raises:
            OSError: if the image file cannot be read or decoded.
        """
        image_path = self.images[index]
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread returns None instead of raising on a missing/corrupt file
            raise OSError("Could not read image: {}".format(image_path))
        label = captcha_to_vector(image_path.name.split("_")[0])
        # Apply the transform to the image
        if self.transform is not None:
            image = self.transform(image)
        # Apply the transform to the label (previously accepted but never used)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

Split and create data sets

In [None]:
# Shuffle once so the three subsets are random, disjoint slices of the data
random.shuffle(images)

# The first 8000 shuffled images form the training part, the rest is the test set
training, test_data = images[:8000], images[8000:]
# Inside the training part: the first 6000 images are for fitting, the rest for validation
train_data, valid_data = training[:6000], training[6000:]

print("test set size:", len(test_data))
print("validation set size:", len(valid_data))
print("train set size:", len(train_data))

# Wrap each subset in a dataset that converts the images to tensors
train_set, valid_set, test_set = (
    CustomDataset(subset, transform=transforms.ToTensor())
    for subset in (train_data, valid_data, test_data)
)

# Training uses mini-batches; test and validation feed one captcha at a time
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=1, shuffle=True)
valid_dataloader = DataLoader(valid_set, batch_size=1, shuffle=True)

## 2) CNN Model

In [None]:
class CNN(nn.Module):
    """Four conv stages followed by two linear layers, producing 36*6 logits.

    Every conv stage is Conv2d(5x5, padding 2) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2), so each stage halves the spatial size. The first linear
    layer expects 64*3*12 features, which is what a 3-channel 50x200 input
    yields after the four stages.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = self._conv_stage(3, 32)
        self.layer2 = self._conv_stage(32, 48)
        self.layer3 = self._conv_stage(48, 64)
        self.layer4 = self._conv_stage(64, 64)
        self.layer5 = nn.Linear(64 * 3 * 12, 512)
        self.out = nn.Linear(512, 36 * 6)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one Conv -> BatchNorm -> ReLU -> MaxPool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Map a batch of images to a batch of 36*6 scores.

        Shapes for a batch of 100 images of size 50x200:
            input  [100, 3, 50, 200] -> layer1 [100, 32, 25, 100]
            -> layer2 [100, 48, 12, 50] -> layer3 [100, 64, 6, 25]
            -> layer4 [100, 64, 3, 12] -> flattened [100, 64*3*12]
            -> output [100, 36*6]
        """
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Flatten the feature maps into one vector per sample
        flat = x.view(-1, 64 * 3 * 12)
        hidden = self.layer5(flat)
        return self.out(hidden)

In [None]:
# Build the network, then move its parameters to the selected device
model = CNN()
model = model.to(device)

In [None]:
def valid(model, valid_dataloader, device):
    """Return the whole-captcha accuracy (in percent) on a validation set.

    A captcha counts as correct only when all 6 characters are predicted
    correctly. Works for any batch size; accuracy is computed over samples,
    not over batches.

    Args:
        model: network mapping an image batch to ``(batch, 6*36)`` scores.
        valid_dataloader: iterable of ``(images, labels)`` batches, where
            ``labels`` holds ``(batch, 6*36)`` one-hot encoded captchas.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 when the loader yields no samples.

    Note:
        The model is left in evaluation mode.
    """
    num_correct = 0  # captchas predicted entirely correctly
    num_total = 0  # captchas seen so far
    model.eval()  # set the evaluation state of the model
    with torch.no_grad():
        for images, labels in valid_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            # One row of 36 class scores per character: (batch, 6, 36)
            labels = labels.reshape(-1, 6, 36)
            output = output.reshape(-1, 6, 36)
            # true and predicted character indices: (batch, 6)
            labels = torch.argmax(labels, dim=2)
            output = torch.argmax(output, dim=2)
            # a sample is correct only if every one of its characters matches
            num_correct += (labels == output).all(dim=1).sum().item()
            num_total += labels.size(0)
    if num_total == 0:
        return 0.0
    return num_correct / num_total * 100


def test(model, test_dataloader, device,conf_matrix):
    """Return the whole-captcha accuracy (in percent) and fill a confusion matrix.

    A captcha counts as correct only when all 6 characters are predicted
    correctly. Works for any batch size; accuracy is computed over samples,
    not over batches.

    Args:
        model: network mapping an image batch to ``(batch, 6*36)`` scores.
        test_dataloader: iterable of ``(images, labels)`` batches, where
            ``labels`` holds ``(batch, 6*36)`` one-hot encoded captchas.
        device: device the batches are moved to before the forward pass.
        conf_matrix: character-level confusion matrix, updated through
            ``update_matrix`` for every sample.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 when the loader yields no samples.

    Note:
        The model is left in evaluation mode.
    """
    num_correct = 0  # captchas predicted entirely correctly
    num_total = 0  # captchas seen so far
    model.eval()  # set the evaluation state of the model

    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            # One row of 36 class scores per character: (batch, 6, 36)
            labels = labels.reshape(-1, 6, 36)
            output = output.reshape(-1, 6, 36)
            # true and predicted character indices: (batch, 6)
            labels = torch.argmax(labels, dim=2)
            output = torch.argmax(output, dim=2)
            # a sample is correct only if every one of its characters matches
            num_correct += (labels == output).all(dim=1).sum().item()
            num_total += labels.size(0)
            # Record every character of every sample in the confusion matrix;
            # tolist() hands plain Python ints to the decoder
            for predicted, actual in zip(output.tolist(), labels.tolist()):
                conf_matrix = update_matrix(
                    vector_to_captcha(predicted), vector_to_captcha(actual), conf_matrix
                )

    if num_total == 0:
        return 0.0
    return num_correct / num_total * 100

In [None]:
def train(model, train_dataloader, valid_dataloader, device, epochs=15):
    """Train the model and report validation accuracy after every epoch.

    Uses ``nn.MultiLabelSoftMarginLoss`` and Adam with ``LEARNING_RATE``.
    After each epoch one line is printed with the loss of the last training
    batch of that epoch and the accuracy returned by ``valid``.

    Args:
        model: network to optimise (modified in place).
        train_dataloader: iterable of ``(images, labels)`` training batches.
        valid_dataloader: loader passed to ``valid`` after each epoch.
        device: device the batches are moved to.
        epochs: number of passes over the training data (default 15).
    """
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    for epoch in range(epochs):
        # valid() switches the model to eval mode, so re-enable training mode here
        model.train()
        loss = None  # stays None when the training loader yields no batches
        for images, labels in train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            predict = model(images)
            loss = criterion(predict, labels)
            loss.backward()
            optimizer.step()
        accuracy = valid(model, valid_dataloader, device)
        # Loss of the last batch of the epoch; NaN if nothing was trained on
        last_loss = loss.item() if loss is not None else float("nan")
        print("epoch: {} loss: {:.10f} accuracy: {:.4f}".format((epoch+1), last_loss, accuracy))

In [None]:

# Fit the model on the training set (validation accuracy is printed per epoch)
print("\nTraining")
train(model, train_dataloader, valid_dataloader, device)

# Evaluate on the held-out test set, collecting a 36x36 character confusion matrix
conf_matrix = torch.zeros((36, 36))
print("\nTesting")
accuracy = test(model, test_dataloader, device, conf_matrix)
print("Accuracy: {}".format(accuracy))

In [None]:
# Print the true text and the predicted text for five test captchas.
# The reshape to (6, 36) below requires the test loader to use batch_size=1.
model.eval()  # make BatchNorm use its running statistics for these single-image batches
with torch.no_grad():
    # One iterator for all five draws: the samples are distinct and the
    # loader is not restarted on every pass
    for image, label in itertools.islice(test_dataloader, 5):
        image = image.to(device)
        label = label.to(device)
        output = model(image)
        # one row of 36 class scores per character
        label = label.reshape(6, 36)
        output = output.reshape(6, 36)
        # character indices of the ground truth and of the prediction
        label = torch.argmax(label, dim=1)
        output = torch.argmax(output, dim=1)
        origin = vector_to_captcha(label)
        predict = vector_to_captcha(output)

        print("Origin: "+origin+"   Predict: "+predict)

## Draw the confusion matrix
Reference: https://www.daimajiaoliu.com/daima/4ed46a79a900402

In [None]:
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Print and plot a confusion matrix as an annotated heat map.

    Rows are drawn on the y axis (labelled 'True label') and columns on the
    x axis (labelled 'Predicted label').

    Args:
        cm: 2-D numpy array of counts.
        classes: class names used as tick labels on both axes.
        normalize: when True, each row is divided by its sum and cells are
            annotated with two decimals; rows whose sum is zero stay all-zero.
        title: plot title.
        cmap: matplotlib colormap used for the heat map.
    """
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Divide only where the row sum is non-zero so empty rows give 0, not NaN
        cm = np.divide(
            cm.astype('float'),
            row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums != 0,
        )
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # Keep the cells square and hide the frame by pinning/whitening the spines
    plt.axis("equal")
    ax = plt.gca()
    left, right = plt.xlim()
    ax.spines['left'].set_position(('data', left))
    ax.spines['right'].set_position(('data', right))
    for edge_i in ['top', 'bottom', 'right', 'left']:
        ax.spines[edge_i].set_edgecolor("white")

    # Cells darker than half the maximum get white text for contrast
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j]
        num = '{:.2f}'.format(value) if normalize else int(value)
        plt.text(j, i, num,
                 verticalalignment='center',
                 horizontalalignment="center",
                 # compare the numeric cell value; the formatted text may be a str
                 color="white" if value > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are taken into account
    plt.tight_layout()
    plt.show()

In [None]:
# Draw the character-level confusion matrix collected during testing
plot_confusion_matrix(cm=conf_matrix.numpy(), classes=TABLE)