In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import cv2
import os
import numpy as np

# Hyperparameters
IMG_HEIGHT, IMG_WIDTH = 300, 300
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 1e-4
MAX_TEXT_LENGTH = 10  # Adjust based on dataset

# Character set
CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}  # Leave 0 for blank token
IDX2CHAR = {i: c for c, i in CHAR2IDX.items()}

# Dataset
class CaptchaDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        for word in os.listdir(root_dir):
            word_path = os.path.join(root_dir, word)
            easy = 0
            hard_hollow = 0
            hard_normal = 0
            green = 0
            red = 0
            if os.path.isdir(word_path):
                for img_file in os.listdir(word_path):
                    if img_file.startswith("easy"):
                        easy += 1
                        if easy > 1:
                            continue
                    elif img_file.startswith("hardhollow"):
                        hard_hollow += 1
                        if hard_hollow > 5:
                            continue
                    elif img_file.startswith("hardnormal"):
                        hard_normal += 1
                        if hard_normal > 5:
                            continue
                    elif img_file.startswith("green"):
                        green += 1
                        if green > 5:
                            continue
                    elif img_file.startswith("red"):
                        red += 1
                        if red > 10:
                            continue
                    self.data.append((os.path.join(word_path, img_file), word))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        image = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        image = np.expand_dims(image, axis=0) / 255.0  # Normalize
        label_encoded = [CHAR2IDX[c] for c in label]
        return torch.FloatTensor(image), torch.LongTensor(label_encoded)

# Model
class CRNN(nn.Module):
    def __init__(self, num_classes=len(CHARS) + 1):
        super(CRNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 2))
        )
        self.rnn = nn.LSTM(512 * (IMG_WIDTH // 16), 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.cnn(x)
        x = x.permute(0, 2, 3, 1).contiguous()
        x = x.view(x.size(0), x.size(1), -1)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x

# Loss and Training
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)  # Stack images normally
    labels = pad_sequence(labels, batch_first=True, padding_value=0)  # Pad labels
    return images, labels

def train_model():
    dataset = CaptchaDataset("dataset", transform=None)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    model = CRNN().cuda()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    for epoch in range(EPOCHS):
        total_loss = 0
        for images, labels in dataloader:
            images, labels = images.cuda(), labels.cuda()
            
            optimizer.zero_grad()
            logits = model(images)
            
            input_lengths = torch.full((logits.size(0),), logits.size(1), dtype=torch.long).cuda()
            target_lengths = torch.tensor([len(lbl) for lbl in labels], dtype=torch.long, device=logits.device)
            if logits.size(0) > 0:
                loss = criterion(logits.log_softmax(2).permute(1, 0, 2), labels, input_lengths, target_lengths)
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss / len(dataloader):.4f}")
    
    torch.save(model.state_dict(), "captcha_model.pth")
    print("Model saved!")

In [9]:
train_model()

Epoch 1/30, Loss: 3.3651
Epoch 2/30, Loss: 2.6109
Epoch 3/30, Loss: 2.5021
Epoch 4/30, Loss: 2.4170
Epoch 5/30, Loss: 2.3104
Epoch 6/30, Loss: 2.1981
Epoch 7/30, Loss: 2.1529
Epoch 8/30, Loss: 2.1148
Epoch 9/30, Loss: 2.0908
Epoch 10/30, Loss: 2.0655
Epoch 11/30, Loss: 2.0483
Epoch 12/30, Loss: 2.0228
Epoch 13/30, Loss: 2.0047
Epoch 14/30, Loss: 1.9751
Epoch 15/30, Loss: 1.9440
Epoch 16/30, Loss: 1.9118
Epoch 17/30, Loss: 1.8769
Epoch 18/30, Loss: 1.8335
Epoch 19/30, Loss: 1.7954
Epoch 20/30, Loss: 1.7410
Epoch 21/30, Loss: 1.6859
Epoch 22/30, Loss: 1.6176
Epoch 23/30, Loss: 1.5467
Epoch 24/30, Loss: 1.4718
Epoch 25/30, Loss: 1.3926
Epoch 26/30, Loss: 1.3139
Epoch 27/30, Loss: 1.2299
Epoch 28/30, Loss: 1.1408
Epoch 29/30, Loss: 1.0563
Epoch 30/30, Loss: 0.9704
Model saved!
