In [1]:
import glob
import os

import albumentations
import numpy as np
import torch
from PIL import Image
from PIL import ImageFile
from sklearn import metrics
from sklearn import model_selection
from sklearn import preprocessing
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

In [2]:
!gdown --id 1qm0Ty5onZX3H1Our7TvaCFOCvNu3wvyw

Downloading...
From: https://drive.google.com/uc?id=1qm0Ty5onZX3H1Our7TvaCFOCvNu3wvyw
To: /content/captcha_images_v2.zip
100% 9.08M/9.08M [00:00<00:00, 19.5MB/s]


In [3]:
!unzip captcha_images_v2.zip

Archive:  captcha_images_v2.zip
   creating: captcha_images_v2/
 extracting: captcha_images_v2/ydd3g.png  
 extracting: captcha_images_v2/36nx4.png  
 extracting: captcha_images_v2/3bnyf.png  
 extracting: captcha_images_v2/8y6b3.png  
 extracting: captcha_images_v2/268g2.png  
 extracting: captcha_images_v2/mnef5.png  
 extracting: captcha_images_v2/5p8fm.png  
 extracting: captcha_images_v2/bxxfc.png  
 extracting: captcha_images_v2/8ypdn.png  
 extracting: captcha_images_v2/gpxng.png  
 extracting: captcha_images_v2/e4gd7.png  
 extracting: captcha_images_v2/pbpgc.png  
 extracting: captcha_images_v2/m4g8g.png  
  inflating: captcha_images_v2/c43b4.png  
 extracting: captcha_images_v2/gwnm6.png  
 extracting: captcha_images_v2/w4x2m.png  
 extracting: captcha_images_v2/npxb7.png  
 extracting: captcha_images_v2/445cc.png  
 extracting: captcha_images_v2/pg2pm.png  
 extracting: captcha_images_v2/wc2bd.png  
 extracting: captcha_images_v2/mc8w2.png  
 extracting: captcha_images_v2/67

### Dataloader

In [4]:
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [5]:
class ClassificationDataset:
    def __init__(self, image_paths, targets, resize=None):
        # resize = (height, width)
        self.image_paths = image_paths
        self.targets = targets
        self.resize = resize

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        self.aug = albumentations.Compose(
            [
                albumentations.Normalize(
                    mean, std, max_pixel_value=255.0, always_apply=True
                )
            ]
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        image = Image.open(self.image_paths[item]).convert("RGB")
        targets = self.targets[item]

        if self.resize is not None:
            image = image.resize(
                (self.resize[1], self.resize[0]), resample=Image.BILINEAR
            )

        image = np.array(image)
        augmented = self.aug(image=image)
        image = augmented["image"]
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)

        return {
            "images": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(targets, dtype=torch.long),
        }

### Model

In [6]:
class CaptchaModel(nn.Module):
    def __init__(self, num_chars):
        super(CaptchaModel, self).__init__()
        self.conv_1 = nn.Conv2d(3, 128, kernel_size=(3, 6), padding=(1, 1))
        self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv_2 = nn.Conv2d(128, 64, kernel_size=(3, 6), padding=(1, 1))
        self.pool_2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.linear_1 = nn.Linear(1152, 64)
        self.drop_1 = nn.Dropout(0.2)
        self.lstm = nn.GRU(64, 32, bidirectional=True, num_layers=2, dropout=0.25, batch_first=True)
        self.output = nn.Linear(64, num_chars + 1)

    def forward(self, images, targets=None):
        bs, _, _, _ = images.size()
        x = F.relu(self.conv_1(images))
        x = self.pool_1(x)
        x = F.relu(self.conv_2(x))
        x = self.pool_2(x)
        x = x.permute(0, 3, 1, 2)
        x = x.view(bs, x.size(1), -1)
        x = F.relu(self.linear_1(x))
        x = self.drop_1(x)
        x, _ = self.lstm(x)
        x = self.output(x)
        x = x.permute(1, 0, 2)

        if targets is not None:
            log_probs = F.log_softmax(x, 2)
            input_lengths = torch.full(
                size=(bs,), fill_value=log_probs.size(0), dtype=torch.int32
            )
            target_lengths = torch.full(
                size=(bs,), fill_value=targets.size(1), dtype=torch.int32
            )
            loss = nn.CTCLoss(blank=0)(
                log_probs, targets, input_lengths, target_lengths
            )
            return x, loss

        return x, None

### Train

In [7]:
DATA_DIR = "captcha_images_v2"
BATCH_SIZE = 2
IMAGE_WIDTH = 100
IMAGE_HEIGHT = 75
NUM_WORKERS = 8
EPOCHS = 10
DEVICE = "cuda"

In [8]:
def train_fn(model, data_loader, optimizer):
    model.train()
    fin_loss = 0
    tk0 = tqdm(data_loader, total=len(data_loader))
    for data in tk0:
        for key, value in data.items():
            data[key] = value.to(DEVICE)
        optimizer.zero_grad()
        _, loss = model(**data)
        loss.backward()
        optimizer.step()
        fin_loss += loss.item()
    return fin_loss / len(data_loader)


def eval_fn(model, data_loader):
    model.eval()
    fin_loss = 0
    fin_preds = []
    tk0 = tqdm(data_loader, total=len(data_loader))
    for data in tk0:
        for key, value in data.items():
            data[key] = value.to(DEVICE)
        batch_preds, loss = model(**data)
        fin_loss += loss.item()
        fin_preds.append(batch_preds)
    return fin_preds, fin_loss / len(data_loader)



def remove_duplicates(x):
    if len(x) < 2:
        return x
    fin = ""
    for j in x:
        if fin == "":
            fin = j
        else:
            if j == fin[-1]:
                continue
            else:
                fin = fin + j
    return fin


def decode_predictions(preds, encoder):
    preds = preds.permute(1, 0, 2)
    preds = torch.softmax(preds, 2)
    preds = torch.argmax(preds, 2)
    preds = preds.detach().cpu().numpy()
    cap_preds = []
    for j in range(preds.shape[0]):
        temp = []
        for k in preds[j, :]:
            k = k - 1
            if k == -1:
                temp.append("§")
            else:
                p = encoder.inverse_transform([k])[0]
                temp.append(p)
        tp = "".join(temp).replace("§", "")
        cap_preds.append(remove_duplicates(tp))
    return cap_preds


In [9]:
def run_training():
    image_files = glob.glob(os.path.join(DATA_DIR, "*.png"))
    targets_orig = [x.split("/")[-1][:-4] for x in image_files]
    targets = [[c for c in x] for x in targets_orig]
    targets_flat = [c for clist in targets for c in clist]

    lbl_enc = preprocessing.LabelEncoder()
    lbl_enc.fit(targets_flat)
    targets_enc = [lbl_enc.transform(x) for x in targets]
    targets_enc = np.array(targets_enc)
    targets_enc = targets_enc + 1

    (
        train_imgs,
        test_imgs,
        train_targets,
        test_targets,
        _,
        test_targets_orig,
    ) = model_selection.train_test_split(
        image_files, targets_enc, targets_orig, test_size=0.1, random_state=42
    )

    train_dataset = ClassificationDataset(
        image_paths=train_imgs,
        targets=train_targets,
        resize=(IMAGE_HEIGHT, IMAGE_WIDTH),
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )
    test_dataset = ClassificationDataset(
        image_paths=test_imgs,
        targets=test_targets,
        resize=(IMAGE_HEIGHT, IMAGE_WIDTH),
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=False,
    )

    model = CaptchaModel(num_chars=len(lbl_enc.classes_))
    model.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.8, patience=5, verbose=True
    )
    for epoch in range(EPOCHS):
        train_loss = train_fn(model, train_loader, optimizer)
        valid_preds, test_loss = eval_fn(model, test_loader)
        valid_captcha_preds = []
        for vp in valid_preds:
            current_preds = decode_predictions(vp, lbl_enc)
            valid_captcha_preds.extend(current_preds)
        combined = list(zip(test_targets_orig, valid_captcha_preds))
        print(combined[:10])
        test_dup_rem = [remove_duplicates(c) for c in test_targets_orig]
        accuracy = metrics.accuracy_score(test_dup_rem, valid_captcha_preds)
        print(
            f"Epoch={epoch}, Train Loss={train_loss}, Test Loss={test_loss} Accuracy={accuracy}"
        )
        scheduler.step(test_loss)

In [10]:
run_training()

100%|██████████| 468/468 [00:05<00:00, 85.29it/s] 
100%|██████████| 52/52 [00:00<00:00, 110.48it/s]


[('67dey', ''), ('7xd5m', ''), ('xp24p', ''), ('mmy5n', ''), ('42xpy', ''), ('8d2nd', ''), ('63824', ''), ('6n443', ''), ('cfc56', ''), ('6fn84', '')]
Epoch=0, Train Loss=3.558569415512248, Test Loss=3.2777999960459194 Accuracy=0.0


100%|██████████| 468/468 [00:04<00:00, 113.76it/s]
100%|██████████| 52/52 [00:00<00:00, 132.98it/s]


[('67dey', ''), ('7xd5m', ''), ('xp24p', ''), ('mmy5n', ''), ('42xpy', ''), ('8d2nd', ''), ('63824', ''), ('6n443', ''), ('cfc56', ''), ('6fn84', '')]
Epoch=1, Train Loss=3.228657102992392, Test Loss=3.2219000412867618 Accuracy=0.0


100%|██████████| 468/468 [00:03<00:00, 122.96it/s]
100%|██████████| 52/52 [00:00<00:00, 150.58it/s]


[('67dey', ''), ('7xd5m', ''), ('xp24p', ''), ('mmy5n', ''), ('42xpy', ''), ('8d2nd', ''), ('63824', '4'), ('6n443', ''), ('cfc56', ''), ('6fn84', '')]
Epoch=2, Train Loss=2.99124809106191, Test Loss=2.5455210163043094 Accuracy=0.0


100%|██████████| 468/468 [00:03<00:00, 124.85it/s]
100%|██████████| 52/52 [00:00<00:00, 140.14it/s]


[('67dey', '7dy'), ('7xd5m', '7d5'), ('xp24p', '24g'), ('mmy5n', 'ny5n'), ('42xpy', '42gy'), ('8d2nd', 'd2d'), ('63824', '324'), ('6n443', '6n42'), ('cfc56', 'f'), ('6fn84', 'fb4')]
Epoch=3, Train Loss=2.2109287888066382, Test Loss=1.452664689375804 Accuracy=0.028846153846153848


100%|██████████| 468/468 [00:03<00:00, 117.47it/s]
100%|██████████| 52/52 [00:00<00:00, 134.78it/s]


[('67dey', '67dwy'), ('7xd5m', '7d5n'), ('xp24p', 'xp24p'), ('mmy5n', 'ny5n'), ('42xpy', '42cpy'), ('8d2nd', '8d2nd'), ('63824', '63824'), ('6n443', '6n43'), ('cfc56', 'cfc56'), ('6fn84', '6fn84')]
Epoch=4, Train Loss=1.2354020702405872, Test Loss=0.6632180299896461 Accuracy=0.4807692307692308


100%|██████████| 468/468 [00:03<00:00, 126.45it/s]
100%|██████████| 52/52 [00:00<00:00, 139.89it/s]


[('67dey', '67dey'), ('7xd5m', '7xd5n'), ('xp24p', 'xp24p'), ('mmy5n', 'ny5n'), ('42xpy', '42xpy'), ('8d2nd', '8d2nd'), ('63824', '63824'), ('6n443', '6n43'), ('cfc56', 'cfc56'), ('6fn84', '6fn84')]
Epoch=5, Train Loss=0.6800088981150562, Test Loss=0.36412019626452374 Accuracy=0.6923076923076923


100%|██████████| 468/468 [00:03<00:00, 123.20it/s]
100%|██████████| 52/52 [00:00<00:00, 129.46it/s]


[('67dey', '67dey'), ('7xd5m', '7xd5n'), ('xp24p', 'xp24p'), ('mmy5n', 'nmy5n'), ('42xpy', '42xpy'), ('8d2nd', '8d2d'), ('63824', '63824'), ('6n443', '6n43'), ('cfc56', 'cfc6'), ('6fn84', '6f84')]
Epoch=6, Train Loss=0.44128436279984623, Test Loss=0.23963022289367822 Accuracy=0.7211538461538461


100%|██████████| 468/468 [00:03<00:00, 119.55it/s]
100%|██████████| 52/52 [00:00<00:00, 128.16it/s]


[('67dey', '67dey'), ('7xd5m', '7xd5n'), ('xp24p', 'xp24p'), ('mmy5n', 'nmy5n'), ('42xpy', '42xpy'), ('8d2nd', '8d2nd'), ('63824', '63824'), ('6n443', '6n43'), ('cfc56', 'cfc56'), ('6fn84', '6fn84')]
Epoch=7, Train Loss=0.3320102366244691, Test Loss=0.17967865272210196 Accuracy=0.7211538461538461


100%|██████████| 468/468 [00:03<00:00, 123.31it/s]
100%|██████████| 52/52 [00:00<00:00, 137.80it/s]


[('67dey', '67dey'), ('7xd5m', '7xd5m'), ('xp24p', 'xp24p'), ('mmy5n', 'my5n'), ('42xpy', '42xpy'), ('8d2nd', '8d2nd'), ('63824', '63824'), ('6n443', '6n43'), ('cfc56', 'cfc56'), ('6fn84', '6fn84')]
Epoch=8, Train Loss=0.25831728149205446, Test Loss=0.11745786494933642 Accuracy=0.8942307692307693


100%|██████████| 468/468 [00:03<00:00, 121.24it/s]
100%|██████████| 52/52 [00:00<00:00, 139.13it/s]


[('67dey', '67dey'), ('7xd5m', '7xd5m'), ('xp24p', 'xp24p'), ('mmy5n', 'my5n'), ('42xpy', '42xpy'), ('8d2nd', '8d2nd'), ('63824', '63824'), ('6n443', '6n43'), ('cfc56', 'cfc56'), ('6fn84', '6fn84')]
Epoch=9, Train Loss=0.20189004184471238, Test Loss=0.0892209647080073 Accuracy=0.9134615384615384
