### Data preprocessing

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def resize_and_pad(img, target_size=(40, 120)):
    """
    Resize an image proportionally and centre it on a white canvas of fixed size.

    The image is scaled by the largest factor that fits inside both target
    dimensions (aspect ratio preserved), resized with nearest-neighbour
    interpolation, and pasted in the middle of a canvas filled with 255.

    Args:
        img: Input image as a numpy array, either (H, W) grayscale or
            (H, W, C) multi-channel.
        target_size: Target (height, width) tuple

    Returns:
        Padded image of shape (target_h, target_w) or (target_h, target_w, C),
        with the same dtype as ``img``.
    """
    target_h, target_w = target_size
    h, w = img.shape[:2]
    extra_dims = img.shape[2:]  # () for grayscale, (C,) for multi-channel

    # Resize proportionally to fit within target dimensions.
    scale = min(target_w / w, target_h / h)
    # round() rather than int(): floating-point error can leave w * scale a hair
    # below the target, and truncation would then lose a pixel on the limiting
    # side. Clamping to [1, target] keeps cv2.resize from receiving a zero size
    # and guarantees the result fits on the canvas.
    new_w = min(target_w, max(1, round(w * scale)))
    new_h = min(target_h, max(1, round(h * scale)))
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    # cv2.resize drops a trailing singleton channel axis; restore the input layout.
    resized = resized.reshape((new_h, new_w) + extra_dims)

    # White canvas with the same channel layout and dtype as the input.
    padded = np.full((target_h, target_w) + extra_dims, 255, dtype=img.dtype)

    # Center the resized image.
    x_offset = (target_w - new_w) // 2
    y_offset = (target_h - new_h) // 2
    padded[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized

    return padded


def remove_black_lines(img):
    """
    Paint over dark pixels (e.g. strike-through lines) in a BGR CAPTCHA image.

    A pixel counts as "dark" when its HSV value channel is <= 80, whatever its
    hue or saturation. Those pixels are filled in with OpenCV's Telea
    inpainting (radius 1). Note that any dark glyph pixels meet the same
    threshold and are inpainted as well.

    Args:
        img: Input BGR image

    Returns:
        cleaned: Copy of ``img`` with the dark pixels inpainted
        mask_black: uint8 mask (255 where a pixel was treated as dark)
    """
    # Full hue and saturation range; only brightness (V) decides "dark".
    dark_lo = np.array([0, 0, 0])
    dark_hi = np.array([180, 255, 80])

    line_mask = cv2.inRange(cv2.cvtColor(img, cv2.COLOR_BGR2HSV), dark_lo, dark_hi)
    repaired = cv2.inpaint(img, line_mask, inpaintRadius=1, flags=cv2.INPAINT_TELEA)
    return repaired, line_mask





In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os

class CaptchaDataset(Dataset):
    """
    Basic CAPTCHA dataset for initial preprocessing exploration.

    Each item is read from ``folder``. Dark strike-through pixels are removed with
    ``remove_black_lines``, then the image is converted to grayscale,
    histogram-equalized and resized/padded to 80x280. The label is the filename
    prefix (``<label>-<n>.png``).
    """
    def __init__(self, folder):
        self.folder = folder
        # Keep only .png files, as CaptchaCTCDataset does: any other entry would
        # make cv2.imread return None and crash __getitem__. Sorting makes the
        # order independent of the filesystem's listing order.
        self.files = sorted(f for f in os.listdir(folder) if f.endswith('.png'))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """
        Load and preprocess one image.

        Returns:
            img: preprocessed grayscale array of shape (80, 280)
            label: text label parsed from the filename
            filename: name of the file inside ``folder``

        Raises:
            FileNotFoundError: if OpenCV cannot read the image.
        """
        filename = self.files[idx]
        img_path = os.path.join(self.folder, filename)

        # Load image; cv2.imread signals failure by returning None, not raising.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # Remove black strikethrough lines
        img, _ = remove_black_lines(img)

        # Convert to grayscale
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Enhance contrast using histogram equalization
        img = cv2.equalizeHist(img)

        # Resize and pad to fixed dimensions
        img = resize_and_pad(img, target_size=(80, 280))

        # Extract label from filename (format: label-0.png)
        label = filename.split('-')[0]

        return img, label, filename

# Load dataset for visualization
dataset = CaptchaDataset("/kaggle/input/captcha-training-images/captacha_dataset/train")
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Display sample image
sample_path = "/kaggle/input/captcha-training-images/captacha_dataset/train/0024miih-0.png"
img = cv2.imread(sample_path)
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(f"Could not read sample image: {sample_path}")
print(img.shape)
# OpenCV decodes to BGR channel order while matplotlib expects RGB; without the
# conversion the red and blue channels are swapped on screen.
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.show()


In [None]:
# Visualize preprocessing pipeline: original vs transformed images
images, labels, filenames = next(iter(loader))

for transformed, label, filename in zip(images, labels, filenames):
    # One figure per sample. If the figure is created once outside the loop, only
    # the first sample gets the 15x2 layout: each plt.show() ends that figure, so
    # later subplots land on a new default-sized one.
    plt.figure(figsize=(15, 2))

    # Load original image
    ori_img = cv2.imread(f"/kaggle/input/captcha-training-images/captacha_dataset/train/{filename}")
    plt.subplot(1, 2, 1)
    if ori_img is not None:
        # cv2 loads BGR; convert so matplotlib shows the true colours
        # (cmap has no effect on 3-channel data).
        plt.imshow(cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB))
        plt.title(f"Original: {label}")
    else:
        plt.text(0.5, 0.5, f"Failed to load:\n{filename}", ha='center', va='center')
        plt.title("Load Error")
    plt.axis("off")

    # Display preprocessed image
    plt.subplot(1, 2, 2)
    plt.imshow(transformed.numpy(), cmap="gray")
    plt.title(f"Preprocessed: {label}")
    plt.axis("off")

    plt.show()

### Tokenisation

In [None]:
# Vocabulary: the 10 digits, then A-Z, then a-z (62 symbols in total).
CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

# Token ids start at 1 because id 0 is kept for the CTC blank.
char_to_idx = {ch: tok for tok, ch in enumerate(CHARS, start=1)}
idx_to_char = {tok: ch for ch, tok in char_to_idx.items()}
idx_to_char[0] = ""  # the blank decodes to nothing

# Every character class plus the blank.
num_classes = len(char_to_idx) + 1
print("Total classes (incl. blank):", num_classes)

In [None]:
def encode_label(text, *, strict=False):
    """
    Convert a text label to a list of token ids.

    Each character is looked up in the module-level ``char_to_idx`` mapping.

    Args:
        text: String to encode
        strict: When True, raise instead of silently dropping characters that
            are missing from the vocabulary. The default (False) keeps the
            lenient behaviour, so existing callers are unaffected.

    Returns:
        List of integer token ids, one per in-vocabulary character of ``text``.

    Raises:
        ValueError: if ``strict`` is True and ``text`` contains characters not
            present in ``char_to_idx``.
    """
    if strict:
        # dict.fromkeys keeps first-seen order while de-duplicating.
        unknown = [c for c in dict.fromkeys(text) if c not in char_to_idx]
        if unknown:
            raise ValueError(
                f"Label {text!r} has characters outside the vocabulary: {unknown}"
            )
    return [char_to_idx[c] for c in text if c in char_to_idx]

def decode_label(tokens):
    """
    Map token ids back to text, one character per id.

    Every id is looked up in the module-level ``idx_to_char``; ids with no
    entry contribute "". Consecutive repeated ids are NOT merged here. That is
    part of CTC decoding and must be done before calling this function when
    decoding raw model output.

    Args:
        tokens: Iterable of integer token ids

    Returns:
        Decoded string
    """
    pieces = map(lambda token: idx_to_char.get(token, ""), tokens)
    return "".join(pieces)

In [None]:
# PyTorch Dataset for CTC Training with Data Augmentation

import torchvision.transforms.functional as TF
from PIL import Image
import random

class AugmentTransform:
    """
    Random CAPTCHA-style augmentation for grayscale numpy images.

    With ``training=False`` the transform is the identity: calling it returns
    the very same array object. With ``training=True`` each step below fires
    independently, in this order:

    1. rotation, angle uniform in [-5, 5] degrees              (p = 0.5)
    2. affine shear, uniform in [-10, 10] degrees              (p = 0.5)
    3. translation, dx in [-3, 3] px and dy in [-2, 2] px      (p = 0.5)
    4. 1-2 straight black lines, 1-2 px thick, random ends     (p = 0.7)
    5. brightness scaling by a factor in [0.85, 1.15]          (p = 0.5)
    6. additive Gaussian noise, mean 0 and std 2               (p = 0.5)

    Steps 1-3 operate on a PIL image and fill exposed pixels with white (255).
    Steps 5-6 clip back to [0, 255] and cast to uint8.
    """

    def __init__(self, training=True):
        self.training = training

    def __call__(self, img):
        """
        Augment ``img`` in training mode; otherwise hand it back untouched.

        Args:
            img: Grayscale numpy array (presumably uint8 - the PIL round trip and
                the uint8 casts assume it; confirm with callers)

        Returns:
            Augmented numpy array, or ``img`` itself when not training.
        """
        if not self.training:
            return img
        warped = self._geometric(img)
        return self._photometric(warped)

    @staticmethod
    def _geometric(img):
        """Random rotation, shear and translation; returns a new numpy array."""
        pil = Image.fromarray(img)

        if random.random() > 0.5:
            pil = TF.rotate(pil, random.uniform(-5, 5), fill=255)

        # Shear approximates the slanted glyphs seen in real CAPTCHAs.
        if random.random() > 0.5:
            pil = TF.affine(pil, angle=0, translate=(0, 0), scale=1.0,
                            shear=random.uniform(-10, 10), fill=255)

        if random.random() > 0.5:
            shift = (random.randint(-3, 3), random.randint(-2, 2))
            pil = TF.affine(pil, angle=0, translate=shift, scale=1.0,
                            shear=0, fill=255)

        return np.array(pil)

    @staticmethod
    def _photometric(img):
        """Random black lines, brightness change and Gaussian noise."""
        # Strike-through-like lines, drawn in place on the (fresh) array.
        if random.random() > 0.3:
            height, width = img.shape[0], img.shape[1]
            for _ in range(random.randint(1, 2)):
                start = (random.randint(0, width), random.randint(0, height))
                end = (random.randint(0, width), random.randint(0, height))
                cv2.line(img, start, end, color=0, thickness=random.randint(1, 2))

        if random.random() > 0.5:
            factor = random.uniform(0.85, 1.15)
            img = np.clip(img * factor, 0, 255).astype(np.uint8)

        if random.random() > 0.5:
            noise = np.random.normal(0, 2, img.shape)
            img = np.clip(img + noise, 0, 255).astype(np.uint8)

        return img

class CaptchaCTCDataset(Dataset):
    """
    PyTorch Dataset for CAPTCHA images with CTC encoding.
    Handles preprocessing, augmentation, and label encoding.
    """
    def __init__(self, folder, augment=False, clean_files=None):
        """
        Initialize CAPTCHA dataset.

        Args:
            folder: Path to image folder
            augment: Whether to apply data augmentation
            clean_files: Optional list of filenames to use (for filtered training).
                Used as given (order preserved).
        """
        self.folder = folder

        # Use clean_files if provided, otherwise use all .png files in the folder.
        if clean_files is not None:
            self.files = clean_files
            print(f"Using {len(self.files)} clean/filtered samples")
        else:
            # Sorted so the sample order (and therefore the unshuffled test
            # loader) does not depend on the filesystem's listing order.
            self.files = sorted(f for f in os.listdir(folder) if f.endswith('.png'))
            print(f"Using all {len(self.files)} samples")

        self.augment = AugmentTransform(training=augment)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """
        Load and preprocess a single CAPTCHA image.

        Returns:
            img: float32 tensor of shape (1, 80, 280) with values in [0, 1]
            label_encoded: List of character indices
            label: Original text label

        Raises:
            FileNotFoundError: if OpenCV cannot read the image.
        """
        filename = self.files[idx]
        img_path = os.path.join(self.folder, filename)

        # Load image; cv2.imread returns None on failure instead of raising, which
        # would otherwise surface as an obscure error inside cvtColor/inpaint.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")

        img, _ = remove_black_lines(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Apply augmentation if enabled (identity when augment=False)
        img = self.augment(img)

        # Resize and pad to fixed dimensions
        img = resize_and_pad(img, target_size=(80, 280))

        # Normalize to [0, 1] and convert to tensor
        img = img.astype('float32') / 255.0
        img = torch.tensor(img).unsqueeze(0)  # (1, H, W) - add channel dimension

        # Extract label from filename (format: label-0.png)
        label = filename.split('-')[0]

        # Encode label to character indices
        label_encoded = encode_label(label)

        return img, label_encoded, label

In [None]:
# Custom collate function for CTC - handles variable length labels
def ctc_collate_fn(batch, *, seq_len=70):
    """
    Collate function for DataLoader to handle variable-length labels for CTC.

    Args:
        batch: List of tuples (image, label_encoded, label_text)
        seq_len: Number of time steps the model emits per image, written into
            every entry of ``input_lengths``. The default of 70 matches the CRNN
            in this notebook for 280-px-wide inputs (two stride-2 width pools:
            280 / 4). For another width or architecture, pass e.g.
            ``functools.partial(ctc_collate_fn, seq_len=...)`` as ``collate_fn``.

    Returns:
        images: Tensor of shape (batch_size, 1, H, W)
        targets: 1-D long tensor of all label indices concatenated
        input_lengths: 1-D long tensor, ``seq_len`` for every sample
        target_lengths: 1-D long tensor with the length of each label
        label_texts: Original text labels (for debugging)
    """
    images, labels, label_texts = zip(*batch)

    # Stack images into a batch
    images = torch.stack(images, dim=0)  # (batch_size, 1, H, W)

    # PyTorch CTC accepts all targets concatenated into one 1-D tensor plus a
    # per-sample length tensor.
    targets = torch.tensor([tok for label in labels for tok in label], dtype=torch.long)
    target_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)

    # Every sample has the same output length because the inputs are padded to
    # a fixed width.
    input_lengths = torch.full((images.size(0),), seq_len, dtype=torch.long)

    return images, targets, input_lengths, target_lengths, label_texts

# Build the train/test datasets. Only the training split is augmented; the
# test split reaches the model exactly as preprocessed.
train_dataset = CaptchaCTCDataset(
    "/kaggle/input/captcha-training-images/captacha_dataset/train",
    augment=True,
)
test_dataset = CaptchaCTCDataset(
    "/kaggle/input/captcha-training-images/captacha_dataset/test",
    augment=False,
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
# DataLoaders for CTC training/evaluation: batches of 32, the CTC collate
# function and in-process loading (num_workers=0). Only the training loader
# reshuffles each epoch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          collate_fn=ctc_collate_fn, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=ctc_collate_fn, num_workers=0)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of test batches: {len(test_loader)}")

### Validation

In [None]:
# Validate dataset statistics
print("=== Dataset Statistics ===")
print(f"Total training samples: {len(train_dataset)}")
print(f"Total test samples: {len(test_dataset)}")

# Read labels straight from the filenames ("label-0.png" -> "label"), the same
# rule CaptchaCTCDataset.__getitem__ uses. Indexing the dataset instead would
# load, inpaint and (for train) randomly augment every image only to discard it.
train_labels = [fname.split('-')[0] for fname in train_dataset.files]
test_labels = [fname.split('-')[0] for fname in test_dataset.files]

# Check label lengths
train_lengths = [len(label) for label in train_labels]
print("\nLabel length statistics:")
if train_lengths:
    print(f"  Min label length: {min(train_lengths)}")
    print(f"  Max label length: {max(train_lengths)}")
    print(f"  Average label length: {sum(train_lengths)/len(train_lengths):.2f}")
else:
    print("  No training labels found!")

# Check character distribution
all_chars = set(''.join(train_labels))
print(f"\nUnique characters in dataset: {len(all_chars)}")
print(f"Characters: {''.join(sorted(all_chars))}")

# Check if all characters are in our CHARS vocabulary
unknown_chars = all_chars - set(CHARS)
if unknown_chars:
    print(f"\nWARNING: Unknown characters found: {unknown_chars}")
else:
    print("\nAll characters are in the vocabulary!")

# encode_label drops unknown characters silently, so check the test split too.
test_unknown_chars = set(''.join(test_labels)) - set(CHARS)
if test_unknown_chars:
    print(f"WARNING: Unknown characters in test labels: {test_unknown_chars}")

# Sample labels
print("\nSample labels (first 10):")
for i, label in enumerate(train_labels[:10], start=1):
    print(f"  {i}. {label} (length: {len(label)})")


In [None]:
from collections import Counter

print("=== Comprehensive Tokenization Validation ===\n")

# Labels come straight from the filenames ("label-0.png" -> "label"), the same
# rule CaptchaCTCDataset.__getitem__ uses. Indexing the dataset would load,
# inpaint and augment up to 1000 images just to read their labels.
sampled_labels = [fname.split('-')[0] for fname in train_dataset.files[:1000]]
sampled_encoded = [encode_label(label) for label in sampled_labels]

# Test 1: Character Coverage
print("1. Testing Character Coverage:")
print(f"   CHARS vocabulary: {len(CHARS)} characters")
print(f"   Token IDs: 1 to {len(CHARS)} (0 reserved for CTC blank)")

# Round-trip each character individually
failed_chars = [char for char in CHARS if decode_label(encode_label(char)) != char]
if not failed_chars:
    print(f"   All {len(CHARS)} characters encode/decode correctly!")
else:
    print(f"   Failed characters: {failed_chars}")

# Test 2: Token Range Validation
print("\n2. Testing Token Ranges:")
sample_size = min(100, len(train_dataset))
range_tokens = [tok for encoded in sampled_encoded[:sample_size] for tok in encoded]

if range_tokens:
    min_token, max_token = min(range_tokens), max(range_tokens)
    print(f"   Token range in dataset: {min_token} to {max_token}")
    print(f"   Expected range: 1 to {len(CHARS)}")
    if min_token >= 1 and max_token <= len(CHARS):
        print("   All tokens are within valid range!")
    else:
        print("   Some tokens are out of range!")
else:
    # Without this branch the untouched min/max sentinels would pass the range
    # check and report success for an empty sample.
    print("   No tokens found in sampled labels!")

# Test 3: Reversibility on a sample of the dataset
print("\n3. Testing Reversibility (sample):")
mismatches = 0
for original, encoded in zip(sampled_labels[:sample_size], sampled_encoded[:sample_size]):
    decoded = decode_label(encoded)
    if decoded != original:
        mismatches += 1
        if mismatches <= 5:  # Show first 5 mismatches
            print(f"   Mismatch: '{original}' → {encoded} → '{decoded}'")

if mismatches == 0:
    print(f"   All {sample_size} samples encode/decode correctly!")
else:
    print(f"   Found {mismatches} mismatches out of {sample_size} samples")

# Test 4: Check for Index 0 (should not appear in encoded labels)
print("\n4. Checking for blank token (0) in labels:")
has_blank = any(0 in encoded for encoded in sampled_encoded[:100])

if not has_blank:
    print("   No blank tokens (0) found in encoded labels (correct!)")
else:
    print("   WARNING: Blank token (0) found in labels!")

# Test 5: Token Statistics
print("\n5. Token Statistics:")
all_tokens = [tok for encoded in sampled_encoded for tok in encoded]
token_counts = Counter(all_tokens)
most_common = token_counts.most_common(5)
least_common = token_counts.most_common()[-5:]

print(f"   Total tokens analyzed: {len(all_tokens)}")
print(f"   Unique tokens used: {len(token_counts)}")
print("   Most common tokens:")
for token, count in most_common:
    char = idx_to_char.get(token, '?')
    print(f"      Token {token} ('{char}'): {count} times")
print("   Least common tokens:")
for token, count in least_common:
    char = idx_to_char.get(token, '?')
    print(f"      Token {token} ('{char}'): {count} times")

print("\n" + "="*50)
print("TOKENIZATION VALIDATION COMPLETE")
print("="*50)


### CRNN Model

In [None]:
"""
CTC-CRNN Model for CAPTCHA Recognition

Architecture based on:
- Shi et al. "An End-to-End Trainable Neural Network for Image-based Sequence Recognition"
- Best practices from CAPTCHA recognition research
- WandB CRNN-CTC guide and Kaggle implementations

Components:
1. CNN (VGG-style): Feature extraction from images
2. Bidirectional LSTM: Sequence modeling
3. Fully Connected: Character classification
4. CTC Loss: Alignment-free training
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# =====================================
# ResNet Block for Better Feature Learning
# =====================================
class ResidualBlock(nn.Module):
    """
    Two 3x3 conv + BatchNorm layers with an additive skip connection.

    Computes ReLU(BN(conv(ReLU(BN(conv(x))))) + shortcut), where the shortcut
    is ``x`` itself or ``downsample(x)`` when a projection is supplied. A
    projection is needed whenever the channel count or stride changes the
    tensor's shape.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (default: 1).
        downsample: Optional module mapping ``x`` to the main path's shape.
    """
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Project the shortcut when the main path changes the tensor's shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)


class CRNN(nn.Module):
    """
    Convolutional Recurrent Neural Network with ResNet-style CNN

    Architecture Flow (shapes shown for the default 80x280 input; in general
    the feature map is (512, H // 16, W // 4)):
        Input Image (1, 80, 280)
            ↓
        ResNet CNN (3 stages x 2 residual blocks, with skip connections)
            ↓
        Feature Maps (512, 5, 70)
            ↓
        Map to Sequence (70 time steps, 2560 features each)
            ↓
        Bidirectional LSTM (num_lstm_layers layers, hidden_size per direction)
            ↓
        Fully Connected Layer
            ↓
        Log Softmax
            ↓
        Output (70, N, 63) for CTC Loss

    Key improvement: ResNet skip connections for better gradient flow
    and fine-grained feature discrimination.
    """

    def __init__(
        self,
        img_height=80,
        img_width=280,
        num_classes=63,  # 62 alphanumeric + 1 blank
        hidden_size=256,
        num_lstm_layers=2,
        dropout=0.4
    ):
        """
        Initialize CRNN model

        Args:
            img_height: Input image height, at least 16 (default: 80)
            img_width: Input image width, at least 4 (default: 280)
            num_classes: Number of output classes including blank (default: 63)
            hidden_size: LSTM hidden size per direction (default: 256)
            num_lstm_layers: Number of LSTM layers (default: 2)
            dropout: Dropout between stacked LSTM layers; has no effect with
                a single layer (default: 0.4)

        Raises:
            ValueError: if the input is too small to survive the pooling stages.
        """
        super(CRNN, self).__init__()

        # The CNN halves the height four times and the width twice (see the
        # pools below). Smaller inputs would collapse to an empty feature map.
        if img_height < 16 or img_width < 4:
            raise ValueError(
                f"Input must be at least 16x4 (HxW), got {img_height}x{img_width}"
            )

        self.img_height = img_height
        self.img_width = img_width
        self.num_classes = num_classes
        self.hidden_size = hidden_size

        # =====================================
        # ResNet-style CNN Feature Extractor
        # =====================================
        # Shapes in the comments are for the default (1, 80, 280) input.

        # Initial conv: (1, 80, 280) → (64, 80, 280)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Pool1: (64, 80, 280) → (64, 40, 140)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ResBlock layer1: (64, 40, 140) → (128, 40, 140)
        self.layer1 = self._make_layer(64, 128, blocks=2)

        # Pool2: (128, 40, 140) → (128, 20, 70)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ResBlock layer2: (128, 20, 70) → (256, 20, 70)
        self.layer2 = self._make_layer(128, 256, blocks=2)

        # Pool3: (256, 20, 70) → (256, 10, 70)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 1))  # Only height

        # ResBlock layer3: (256, 10, 70) → (512, 10, 70)
        self.layer3 = self._make_layer(256, 512, blocks=2)

        # Pool4: (512, 10, 70) → (512, 5, 70)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 1))  # Only height

        # Channel-wise dropout on the final feature map (fixed rate, separate
        # from the LSTM ``dropout`` argument).
        self.dropout = nn.Dropout2d(0.2)

        # RNN input size: width becomes the sequence axis, and each time step
        # carries channels * remaining height features. Each of the four pools
        # halves the height with flooring (MaxPool2d's default), so the height
        # left is img_height // 16 (5 for 80 px) instead of a hard-coded 5.
        self.map_to_seq_height = img_height // 16
        self.map_to_seq_channels = 512
        self.rnn_input_size = self.map_to_seq_height * self.map_to_seq_channels

        # =====================================
        # Recurrent Layers (Bidirectional LSTM)
        # =====================================
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            # Honour the ``dropout`` argument; it used to be ignored in favour of
            # a hard-coded 0.3. nn.LSTM applies dropout only between stacked
            # layers and warns if it is non-zero with one layer.
            dropout=dropout if num_lstm_layers > 1 else 0,
            batch_first=False  # (T, N, C) format for CTC
        )

        # =====================================
        # Fully Connected Layer
        # =====================================
        # Map from LSTM output (hidden_size * 2 due to bidirectional) to num_classes
        self.fc = nn.Linear(hidden_size * 2, num_classes)

        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, blocks):
        """
        Create a layer with multiple residual blocks

        Args:
            in_channels: Input channels
            out_channels: Output channels
            blocks: Number of residual blocks

        Returns:
            nn.Sequential of residual blocks
        """
        # 1x1 projection so the first block's shortcut matches the new width.
        downsample = None
        if in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        layers = [ResidualBlock(in_channels, out_channels, stride=1, downsample=downsample)]
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-init convs, unit/zero BatchNorm, small-normal Linear weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input images of shape (N, 1, H, W) with H == img_height and
               W == img_width (80 x 280 by default)

        Returns:
            log_probs: Log probabilities of shape (T, N, C)
                      T = W // 4 (70 by default)
                      N = batch size
                      C = num_classes (63 by default)
        """
        # ===== 1. ResNet CNN Feature Extraction =====
        x = self.pool1(self.conv1(x))     # (N, 64, H/2, W/2)
        x = self.pool2(self.layer1(x))    # (N, 128, H/4, W/4)
        x = self.pool3(self.layer2(x))    # (N, 256, H/8, W/4)
        x = self.pool4(self.layer3(x))    # (N, 512, H/16, W/4)

        conv_out = self.dropout(x)

        batch_size, channels, height, width = conv_out.size()

        # ===== 2. Map to Sequence =====
        # (N, C, H', W') → (N, W', C*H'): width becomes the time axis, and each
        # column's channels x height values become that step's features.
        conv_out = conv_out.permute(0, 3, 1, 2)
        conv_out = conv_out.reshape(batch_size, width, channels * height)

        # ===== 3. Prepare for LSTM =====
        # LSTM expects (T, N, features) when batch_first=False
        rnn_input = conv_out.permute(1, 0, 2)

        # ===== 4. Bidirectional LSTM =====
        rnn_output, _ = self.rnn(rnn_input)  # (T, N, 2 * hidden_size)

        # ===== 5. Fully Connected Layer =====
        T, N, hidden = rnn_output.size()
        output = self.fc(rnn_output.reshape(T * N, hidden))
        output = output.reshape(T, N, self.num_classes)

        # ===== 6. Log Softmax for CTC Loss =====
        return F.log_softmax(output, dim=2)

    def get_sequence_length(self):
        """
        Get the output sequence length after CNN layers

        Returns:
            img_width // 4 (70 for the default input width 280): only pool1 and
            pool2 reduce the width, each halving it; pool3/pool4 pool height only.
        """
        return self.img_width // 4

    def count_parameters(self):
        """Count total and trainable parameters"""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total, trainable


# =====================================
# Helper Functions
# =====================================

def create_model(num_classes=63, hidden_size=256, device='cpu'):
    """
    Build a CRNN for 80x280 inputs, move it to ``device`` and print a summary.

    The model always has 2 LSTM layers and is constructed with ``dropout=0.5``.

    Args:
        num_classes: Output classes including the CTC blank (default: 63)
        hidden_size: LSTM hidden size (default: 256)
        device: Target device (default: 'cpu')

    Returns:
        model: The CRNN instance, already on ``device``
    """
    model = CRNN(
        img_height=80, img_width=280,
        num_classes=num_classes, hidden_size=hidden_size,
        num_lstm_layers=2, dropout=0.5,
    ).to(device)

    total_params, trainable_params = model.count_parameters()
    summary_lines = (
        "Model created successfully!",
        f"Total parameters: {total_params:,}",
        f"Trainable parameters: {trainable_params:,}",
        f"Device: {device}",
    )
    for line in summary_lines:
        print(line)

    return model


def test_model():
    """
    Smoke-test CRNN: run a random (4, 1, 80, 280) batch through it, check the
    (70, 4, 63) output shape with asserts, and print parameter statistics.
    """
    banner = "=" * 60
    print(banner)
    print("Testing CRNN Model")
    print(banner)

    model = CRNN(img_height=80, img_width=280, num_classes=63)
    model.eval()

    batch_size = 4
    dummy_input = torch.randn(batch_size, 1, 80, 280)
    print(f"\nInput shape: {dummy_input.shape}")

    with torch.no_grad():
        output = model(dummy_input)

    print(f"Output shape: {output.shape}")
    print(f"Expected: (70, {batch_size}, 63)")

    seq_len, n_batch, n_classes = output.shape
    assert seq_len == 70, f"Sequence length should be 70, got {seq_len}"
    assert n_batch == batch_size, f"Batch size should be {batch_size}, got {n_batch}"
    assert n_classes == 63, f"Num classes should be 63, got {n_classes}"

    print("\nModel test passed!")

    total, trainable = model.count_parameters()
    size_mb = total * 4 / 1024 / 1024  # 4 bytes per float32 parameter
    print("\nModel Statistics:")
    print(f"  Total parameters: {total:,}")
    print(f"  Trainable parameters: {trainable:,}")
    print(f"  Model size: ~{size_mb:.2f} MB (float32)")

    print(banner)


# Run the CRNN smoke test when this code runs as the top-level module. In a
# notebook kernel __name__ is also "__main__", so executing the cell runs it.
if __name__ == "__main__":
    test_model()



### Training utils

In [None]:
"""
Training Utilities for CTC-CRNN CAPTCHA Recognition

Includes:
- Training loop
- Validation loop
- CTC greedy decoder
- Learning rate scheduling
- Checkpoint management
"""

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import os
from datetime import datetime


# =====================================
# CTC Loss
# =====================================

def get_ctc_loss(blank=0, reduction='mean', zero_infinity=True):
    """
    Build the ``nn.CTCLoss`` criterion used for training.

    Args:
        blank: Token id treated as the CTC blank (default: 0)
        reduction: How per-sample losses are reduced (default: 'mean')
        zero_infinity: Replace infinite losses (and their gradients) with zero,
            which happens when a target cannot be aligned to the input
            (default: True)

    Returns:
        The configured ``nn.CTCLoss`` module
    """
    criterion = nn.CTCLoss(
        blank=blank,
        reduction=reduction,
        zero_infinity=zero_infinity,
    )
    return criterion


# =====================================
# CTC Decoders
# =====================================

def ctc_decode_greedy(log_probs, input_lengths, idx_to_char):
    """
    Best-path (greedy) CTC decoding.

    For every sample: take the arg-max class at each valid time step, merge
    runs of identical consecutive ids, drop the blank id 0, and map the
    surviving ids to characters. Ids missing from ``idx_to_char`` become "".

    Args:
        log_probs: Log probabilities from model (T, N, C)
        input_lengths: Valid time steps per sample; sample i uses the first
            ``input_lengths[i]`` steps
        idx_to_char: Dictionary mapping token IDs to characters

    Returns:
        List of decoded strings, one per entry of ``input_lengths``
    """
    best_path = torch.max(log_probs, dim=2).indices.transpose(0, 1)  # (N, T)

    results = []
    for row_idx, n_steps in enumerate(input_lengths):
        path = best_path[row_idx, :n_steps].tolist()
        # Pair each token with its predecessor (None for the first step) so a
        # token survives only if it is not blank and differs from the one before.
        predecessors = [None] + path[:-1]
        kept = [tok for tok, prev in zip(path, predecessors) if tok != 0 and tok != prev]
        results.append(''.join(idx_to_char.get(tok, '') for tok in kept))

    return results


# =====================================
# Training Functions
# =====================================

def train_one_epoch(model, dataloader, criterion, optimizer, device, idx_to_char=None):
    """
    Train model for one epoch

    Batches whose loss is NaN/inf are skipped (no optimizer step) and are not
    counted in the returned average.

    Args:
        model: CRNN model
        dataloader: Training data loader yielding
            (images, targets, input_lengths, target_lengths, label_texts)
        criterion: CTC loss function
        optimizer: Optimizer
        device: Device (cuda/cpu)
        idx_to_char: Character mapping for decoding (optional; not used here,
            kept for interface compatibility)

    Returns:
        Average loss over the batches that produced a finite loss
        (0.0 if no batch did, e.g. an empty loader)
    """
    model.train()
    total_loss = 0.0
    num_valid_batches = 0

    with tqdm(dataloader, desc="Training", unit="batch") as pbar:
        for batch_idx, (images, targets, input_lengths, target_lengths, _) in enumerate(pbar):
            # Move to device
            images = images.to(device)
            targets = targets.to(device)
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            # Forward pass
            log_probs = model(images)  # (T, N, C)

            # Calculate CTC loss
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            # Skip invalid losses so one bad batch cannot corrupt the weights
            if not torch.isfinite(loss):
                print(f"\nWarning: Invalid loss at batch {batch_idx}: {loss.item()}")
                continue

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping (important for LSTM)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            optimizer.step()

            # Accumulate only batches that actually contributed an update
            loss_value = loss.item()
            total_loss += loss_value
            num_valid_batches += 1
            pbar.set_postfix({'loss': f'{loss_value:.4f}'})

    return total_loss / num_valid_batches if num_valid_batches else 0.0


def validate(model, dataloader, criterion, device, idx_to_char):
    """
    Validate model using greedy decoding

    Args:
        model: CRNN model
        dataloader: Validation data loader yielding
            (images, targets, input_lengths, target_lengths, label_texts)
        criterion: CTC loss function
        device: Device (cuda/cpu)
        idx_to_char: Character mapping for decoding

    Returns:
        avg_loss: Average validation loss over batches with a finite loss
            (inf if no batch produced one, so it never looks like an improvement)
        accuracy: Character-level accuracy (position-wise matches divided by
            the number of ground-truth characters)
        word_accuracy: Word-level accuracy (exact string matches / samples)
    """
    model.eval()
    total_loss = 0.0
    num_valid_batches = 0
    correct_chars = 0
    total_chars = 0
    correct_words = 0
    total_words = 0

    with torch.no_grad():
        with tqdm(dataloader, desc="Validation", unit="batch") as pbar:
            for images, targets, input_lengths, target_lengths, label_texts in pbar:
                # Move to device
                images = images.to(device)
                targets = targets.to(device)
                input_lengths = input_lengths.to(device)
                target_lengths = target_lengths.to(device)

                # Forward pass
                log_probs = model(images)  # (T, N, C)

                # Calculate loss; NaN/inf batches are excluded from the average
                loss = criterion(log_probs, targets, input_lengths, target_lengths)
                if torch.isfinite(loss):
                    total_loss += loss.item()
                    num_valid_batches += 1

                # Decode predictions using greedy decoder
                predictions = ctc_decode_greedy(log_probs, input_lengths, idx_to_char)

                # Calculate accuracy
                for pred, gt in zip(predictions, label_texts):
                    # Word accuracy: the whole string must match
                    correct_words += int(pred == gt)
                    total_words += 1

                    # Character accuracy: aligned position-by-position
                    correct_chars += sum(p == g for p, g in zip(pred, gt))
                    total_chars += len(gt)

                # Update progress bar (guard against an empty first batch)
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{correct_words/total_words:.4f}' if total_words else '0.0000'
                })

    avg_loss = total_loss / num_valid_batches if num_valid_batches else float('inf')
    char_accuracy = correct_chars / total_chars if total_chars > 0 else 0
    word_accuracy = correct_words / total_words if total_words > 0 else 0

    return avg_loss, char_accuracy, word_accuracy


# =====================================
# Checkpoint Management
# =====================================

def save_checkpoint(model, optimizer, epoch, loss, accuracy, filepath):
    """
    Persist model and optimizer state, plus summary metrics, to ``filepath``.

    The saved dict holds: 'epoch', 'model_state_dict', 'optimizer_state_dict',
    'loss', 'accuracy' and a local-time 'timestamp' string.

    Args:
        model: Model whose state_dict is saved
        optimizer: Optimizer whose state_dict is saved
        epoch: Epoch number to record
        loss: Loss value to record
        accuracy: Accuracy value to record
        filepath: Destination path passed to torch.save
    """
    saved_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
        accuracy=accuracy,
        timestamp=saved_at,
    )

    torch.save(payload, filepath)
    print(f"Checkpoint saved: {filepath}")


def load_checkpoint(model, optimizer, filepath, device):
    """
    Load model checkpoint

    Accepts both the dicts written by save_checkpoint() (keys 'loss' /
    'accuracy') and the dicts saved directly by the training script
    (keys 'val_loss' / 'val_acc'; the best-model file stores no loss at all).

    Args:
        model: Model to load weights into
        optimizer: Optimizer to load state into, or None to load weights only
            (e.g. for inference)
        filepath: Path to checkpoint
        device: Device to load on (used as torch.load's map_location)

    Returns:
        epoch: Last epoch number
        loss: Last loss value ('loss', else 'val_loss', else None)
        accuracy: Last accuracy value ('accuracy', else 'val_acc', else 0.0)

    Raises:
        KeyError: if 'model_state_dict' or 'epoch' is missing, or an optimizer
            is given but the checkpoint has no 'optimizer_state_dict'
    """
    checkpoint = torch.load(filepath, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint['epoch']
    # Fall back to the key names used by the training-script checkpoints
    loss = checkpoint.get('loss', checkpoint.get('val_loss'))
    accuracy = checkpoint.get('accuracy', checkpoint.get('val_acc', 0.0))

    loss_str = f"{loss:.4f}" if loss is not None else "n/a"
    print(f"Checkpoint loaded: {filepath}")
    print(f"Epoch: {epoch}, Loss: {loss_str}, Accuracy: {accuracy:.4f}")

    return epoch, loss, accuracy


# =====================================
# Full Training Loop
# =====================================

def train_model(
    model,
    train_loader,
    val_loader,
    idx_to_char,
    num_epochs=50,
    learning_rate=0.001,
    device='cuda',
    checkpoint_dir='checkpoints',
    save_every=5
):
    """
    Complete training loop

    Trains with Adam + CTC loss, steps ReduceLROnPlateau on the validation
    loss, saves a periodic checkpoint every ``save_every`` epochs and keeps
    ``best_model.pth`` for the best word accuracy seen so far.

    Args:
        model: CRNN model
        train_loader: Training data loader
        val_loader: Validation data loader
        idx_to_char: Character mapping
        num_epochs: Number of epochs to train
        learning_rate: Initial learning rate
        device: Device (cuda/cpu)
        checkpoint_dir: Directory to save checkpoints
        save_every: Save checkpoint every N epochs (0/None disables periodic saves)

    Returns:
        Training history dictionary with per-epoch lists under
        'train_loss', 'val_loss', 'val_char_acc', 'val_word_acc' and 'lr'
    """
    # Create checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Setup
    model = model.to(device)
    criterion = get_ctc_loss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated for LR schedulers (PyTorch >= 2.2);
    # LR reductions are reported explicitly in the loop instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_char_acc': [],
        'val_word_acc': [],
        'lr': []
    }

    best_acc = 0.0

    print(f"\nStarting training for {num_epochs} epochs")
    print(f"Device: {device}")
    print("="*60)

    for epoch in range(num_epochs):
        epoch_num = epoch + 1
        print(f"\nEpoch {epoch_num}/{num_epochs}")

        # Train
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, idx_to_char)

        # Validate
        val_loss, char_acc, word_acc = validate(model, val_loader, criterion, device, idx_to_char)

        # Update learning rate (plateau detection on validation loss)
        prev_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < prev_lr:
            print(f"  Reducing learning rate: {prev_lr} -> {current_lr}")

        # Store history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_char_acc'].append(char_acc)
        history['val_word_acc'].append(word_acc)
        history['lr'].append(current_lr)

        # Print epoch summary
        print(f"\nEpoch {epoch_num} Summary:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")
        print(f"  Char Accuracy: {char_acc:.4f}")
        print(f"  Word Accuracy: {word_acc:.4f}")
        print(f"  Learning Rate: {current_lr}")

        # Save periodic checkpoint (guarded so save_every=0 can't divide by zero)
        if save_every and epoch_num % save_every == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch_num}.pth')
            save_checkpoint(model, optimizer, epoch_num, val_loss, word_acc, checkpoint_path)

        # Save best model (by word accuracy)
        if word_acc > best_acc:
            best_acc = word_acc
            best_path = os.path.join(checkpoint_dir, 'best_model.pth')
            save_checkpoint(model, optimizer, epoch_num, val_loss, word_acc, best_path)
            print(f"  New best model! Accuracy: {best_acc:.4f}")

    print("\n" + "="*60)
    print("Training completed!")
    print(f"Best word accuracy: {best_acc:.4f}")

    return history


if __name__ == "__main__":
    # Announce which training helpers this cell/module provides.
    print("Training utilities loaded successfully!")
    print("\nAvailable functions:")
    print("\n".join((
        "  - train_one_epoch()",
        "  - validate()",
        "  - ctc_decode_greedy()",
        "  - save_checkpoint()",
        "  - load_checkpoint()",
        "  - train_model() [Full training loop]",
    )))



### Training script

In [None]:
# Training Configuration and Setup

# Set device (prefers Apple MPS, then CUDA, else CPU)
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using device: mps (Apple Silicon GPU)")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: cuda ({torch.cuda.get_device_name(0)})")
else:
    device = torch.device('cpu')
    print("Using device: cpu")

# Dataset paths
TRAIN_DIR = '/kaggle/input/captcha-training-images/captacha_dataset/train'
TEST_DIR = '/kaggle/input/captcha-training-images/captacha_dataset/test'

# Training hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 150
LEARNING_RATE = 0.001
HIDDEN_SIZE = 384
NUM_LSTM_LAYERS = 2
NUM_WORKERS = 2
SAVE_EVERY = 5  # Save checkpoint every N epochs

# Create checkpoint directory
os.makedirs('checkpoints', exist_ok=True)

# =====================================
# Create Datasets & DataLoaders
# =====================================

print("\nLoading datasets...")

# Create datasets with augmentation for training only
train_dataset = CaptchaCTCDataset(TRAIN_DIR, augment=True)
test_dataset = CaptchaCTCDataset(TEST_DIR, augment=False)

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Pinned host memory only speeds up host->CUDA copies
use_pin_memory = device.type == 'cuda'

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=ctc_collate_fn,
    num_workers=NUM_WORKERS,
    pin_memory=use_pin_memory
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=ctc_collate_fn,
    num_workers=NUM_WORKERS,
    pin_memory=use_pin_memory
)

print(f"Train batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

# =====================================
# Create Model
# =====================================

print("\nCreating model...")
model = CRNN(
    img_height=80,
    img_width=280,
    num_classes=63,
    hidden_size=HIDDEN_SIZE,
    num_lstm_layers=NUM_LSTM_LAYERS,
    dropout=0.3
)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
print("Model created!")
print(f"  Total parameters: {total_params:,}")
# Size estimate assumes 4 bytes (float32) per parameter
print(f"  Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB")

# =====================================
# Training Setup
# =====================================

# CTC Loss with Label Smoothing
class CTCLossWithLabelSmoothing(nn.Module):
    """
    CTC loss blended with a uniform-distribution penalty.

    The result is ``(1 - smoothing) * ctc + smoothing * (-mean(log_probs))``.
    The second term rewards spreading probability mass across classes,
    discouraging over-confident frame predictions. When ``smoothing <= 0``
    the plain CTC loss is returned.
    """
    def __init__(self, blank=0, smoothing=0.1):
        super().__init__()
        self.ctc_loss = nn.CTCLoss(blank=blank, reduction='mean', zero_infinity=True)
        self.smoothing = smoothing
        self.blank = blank

    def forward(self, log_probs, targets, input_lengths, target_lengths):
        ctc = self.ctc_loss(log_probs, targets, input_lengths, target_lengths)

        # Guard clause: no smoothing configured -> plain CTC
        if not self.smoothing > 0:
            return ctc

        # Mean negative log-prob over every (frame, sample, class) entry
        uniform_penalty = -(log_probs.mean())
        return (1 - self.smoothing) * ctc + self.smoothing * uniform_penalty

criterion = CTCLossWithLabelSmoothing(blank=0, smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-5)
# `verbose` is deprecated for LR schedulers (PyTorch >= 2.2); the per-epoch
# summary already prints the current learning rate.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

# Learning rate warmup configuration.
# NOTE: constructing LinearLR immediately scales the LR to
# start_factor * LEARNING_RATE; each warmup_scheduler.step() then ramps it
# linearly back to LEARNING_RATE over WARMUP_EPOCHS steps, after which the
# training loop hands control to ReduceLROnPlateau.
WARMUP_EPOCHS = 5
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,  # Start at 10% of base learning rate
    total_iters=WARMUP_EPOCHS
)

# Initialize training history tracking
history = {
    'train_loss': [],
    'val_loss': [],
    'val_acc': [],
    'char_acc': []
}

best_acc = 0.0

# Early stopping configuration (monitors validation loss)
EARLY_STOP_PATIENCE = 10
early_stop_counter = 0
best_val_loss = float('inf')

print(f"\n{'='*70}")
print(f"Starting Training - {NUM_EPOCHS} Epochs")
print(f"  Early Stopping: Patience = {EARLY_STOP_PATIENCE} epochs")
print(f"{'='*70}\n")

# =====================================
# Training Loop
# =====================================

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"{'='*70}")

    # ===== Train =====
    # Reuse the shared helper instead of duplicating its forward /
    # skip-invalid-loss / clip / step loop inline.
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, idx_to_char)

    # ===== Validate =====
    # validate() greedy-decodes via ctc_decode_greedy and returns
    # (avg_loss, char_accuracy, word/sequence accuracy); both accuracies
    # are 0 when no samples were scored.
    # NOTE(review): test_loader drives LR scheduling, early stopping and
    # best-model selection, so the reported test accuracy is optimistic.
    val_loss, char_acc, val_acc = validate(model, test_loader, criterion, device, idx_to_char)

    # ===== Update Learning Rate =====
    # Use warmup for first epochs, then switch to ReduceLROnPlateau
    if epoch < WARMUP_EPOCHS:
        warmup_scheduler.step()
    else:
        scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # ===== Store History =====
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['char_acc'].append(char_acc)

    # ===== Early Stopping Check =====
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    # ===== Print Summary =====
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"   Train Loss: {train_loss:.4f}")
    print(f"   Val Loss:   {val_loss:.4f}")
    print(f"   Val Acc (Seq): {val_acc:.4f}")
    print(f"   Val Acc (Char): {char_acc:.4f}")
    print(f"   LR:         {current_lr:.6f}")
    print(f"   Early Stop: {early_stop_counter}/{EARLY_STOP_PATIENCE}")

    # ===== Save Checkpoint =====
    if (epoch + 1) % SAVE_EVERY == 0:
        checkpoint_path = f'checkpoints/checkpoint_epoch_{epoch+1}.pth'
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'char_acc': char_acc,
            # Generic keys read by load_checkpoint()
            'loss': val_loss,
            'accuracy': val_acc,
            'history': history
        }, checkpoint_path)
        print(f"   Checkpoint saved: {checkpoint_path}")

    # ===== Save Best Model =====
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
            'char_acc': char_acc,
            # Generic keys read by load_checkpoint()
            'loss': val_loss,
            'accuracy': val_acc
        }, 'checkpoints/best_model.pth')
        print(f"   New best model! Accuracy: {best_acc:.4f}")

    # ===== Early Stopping Trigger =====
    if early_stop_counter >= EARLY_STOP_PATIENCE:
        print(f"\n{'='*70}")
        print(f"Early stopping triggered! No improvement for {EARLY_STOP_PATIENCE} epochs.")
        print(f"{'='*70}")
        break

print(f"\n{'='*70}")
print("Training Complete!")
print(f"{'='*70}")
print(f"Best Accuracy (Seq): {best_acc:.4f}")
# history is empty when no epoch ran (e.g. NUM_EPOCHS == 0); avoid IndexError
if history['train_loss']:
    print(f"Final Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")
    print(f"Final Val Acc (Seq): {history['val_acc'][-1]:.4f}")
    print(f"Final Val Acc (Char): {history['char_acc'][-1]:.4f}")
else:
    print("No epochs were run - no final metrics to report.")



### Results

In [None]:
import matplotlib.pyplot as plt

# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Accuracy plot: sequence-level (exact match) and character-level accuracy
axes[1].plot(history['val_acc'], label='Validation Accuracy', color='green', linewidth=2)
axes[1].plot(history['char_acc'], label='Validation Char Accuracy', color='orange', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy', fontsize=12)
axes[1].set_title('Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].set_ylim([0, 1])
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nTraining Summary:")
# max()/[-1] would fail on an empty history (no epochs run)
if history['val_acc']:
    print(f"   Best Accuracy: {max(history['val_acc']):.4f}")
    print(f"   Final Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"   Final Val Loss: {history['val_loss'][-1]:.4f}")
    print(f"   Final Val Acc: {history['val_acc'][-1]:.4f}")
else:
    print("   No epochs recorded.")
print("\nPlot saved as 'training_history.png'")

# Package checkpoints for download (Kaggle/Colab).
# shutil.make_archive avoids depending on an external `zip` binary; the
# archive contains the checkpoints/ directory itself.
import shutil

try:
    shutil.make_archive('captcha_checkpoints', 'zip', root_dir='.', base_dir='checkpoints')
    print("Checkpoints packaged as 'captcha_checkpoints.zip'")
    print("  Download from Output tab (Kaggle) or Files panel (Colab)")
except OSError as e:
    print(f"Note: could not package checkpoints ({e}). Manual zip - run: !zip -r captcha_checkpoints.zip checkpoints/")



### Error analysis

In [None]:
# Analyze prediction errors to understand what's failing
# Run this AFTER training completes to get insights for next steps

def analyze_errors(model, dataloader, device, idx_to_char, num_samples=200):
    """
    Analyze common prediction errors to identify patterns

    Greedy-decodes the model output for each batch, using each sample's own
    input length. It collects mispredicted samples and prints summary
    statistics: length errors, character confusions, error positions,
    per-length counts, sample errors and key insights.

    Args:
        model: Trained CRNN model
        dataloader: Test data loader yielding
            (images, targets, input_lengths, target_lengths, label_texts)
        device: Device (cuda/cpu)
        idx_to_char: Character mapping (token id -> character); id 0 is
            treated as the CTC blank
        num_samples: Maximum number of error samples to collect

    Returns:
        errors: List of at most ``num_samples`` dicts with keys
            'predicted', 'ground_truth', 'len_diff', 'len_pred', 'len_gt'
    """
    from collections import Counter

    model.eval()
    errors = []

    with torch.no_grad():
        for images, _targets, input_lengths, _target_lengths, label_texts in dataloader:
            if len(errors) >= num_samples:
                break

            images = images.to(device)
            log_probs = model(images)  # presumably (T, N, C), as during training

            # Greedy decode: best class per frame, moved to CPU once -> (N, T)
            max_indices = log_probs.argmax(dim=2).transpose(0, 1).cpu()

            for i, gt_text in enumerate(label_texts):
                # Decode only this sample's valid frames (was hard-coded to 70)
                pred_tokens = max_indices[i, :int(input_lengths[i])].tolist()
                collapsed = []
                prev = None
                for token in pred_tokens:
                    if token != 0 and token != prev:
                        collapsed.append(token)
                    prev = token

                pred_text = ''.join(idx_to_char.get(t, '') for t in collapsed)

                if pred_text != gt_text:
                    errors.append({
                        'predicted': pred_text,
                        'ground_truth': gt_text,
                        'len_diff': len(pred_text) - len(gt_text),
                        'len_pred': len(pred_text),
                        'len_gt': len(gt_text)
                    })

    # The cap is only checked per batch, so trim any overshoot
    errors = errors[:num_samples]

    print(f"\n{'='*70}")
    print(f"Error Analysis - {len(errors)} errors collected")
    print(f"{'='*70}")

    # Every statistic below divides by len(errors)
    if not errors:
        print("\nNo errors found - nothing to analyze.")
        print(f"\n{'='*70}")
        return errors

    n_errors = len(errors)

    # 1. Length Analysis
    len_diffs = [e['len_diff'] for e in errors]
    print("\nLength Errors:")
    too_short = sum(1 for d in len_diffs if d < 0)
    too_long = sum(1 for d in len_diffs if d > 0)
    correct_len = sum(1 for d in len_diffs if d == 0)

    print(f"  Too short (missing chars):   {too_short:3d} ({too_short/n_errors*100:5.1f}%)")
    print(f"  Too long (extra chars):      {too_long:3d} ({too_long/n_errors*100:5.1f}%)")
    print(f"  Correct length (wrong char): {correct_len:3d} ({correct_len/n_errors*100:5.1f}%)")

    # 2. Character confusions (aligned position-by-position over the common prefix length)
    char_errors = []
    positional_errors = {'start': 0, 'middle': 0, 'end': 0}

    for e in errors:
        pred, gt = e['predicted'], e['ground_truth']
        for i in range(min(len(pred), len(gt))):
            if pred[i] != gt[i]:
                char_errors.append((gt[i], pred[i]))

                # Track position: first two, last two, or in between
                if i < 2:
                    positional_errors['start'] += 1
                elif i >= len(gt) - 2:
                    positional_errors['end'] += 1
                else:
                    positional_errors['middle'] += 1

    print("\nTop 15 Character Confusions:")
    confusion_counts = Counter(char_errors)
    for idx, ((gt_char, pred_char), count) in enumerate(confusion_counts.most_common(15), 1):
        print(f"  {idx:2d}. '{gt_char}' → '{pred_char}': {count:3d} times")

    # 3. Positional error analysis
    total_pos_errors = sum(positional_errors.values())
    if total_pos_errors > 0:
        print("\nError Position Distribution:")
        print(f"  Start (0-1):   {positional_errors['start']:3d} ({positional_errors['start']/total_pos_errors*100:5.1f}%)")
        print(f"  Middle (2-n):  {positional_errors['middle']:3d} ({positional_errors['middle']/total_pos_errors*100:5.1f}%)")
        print(f"  End (last 2):  {positional_errors['end']:3d} ({positional_errors['end']/total_pos_errors*100:5.1f}%)")

    # 4. Length distribution
    len_gt_dist = Counter(e['len_gt'] for e in errors)
    print("\nError Distribution by Ground Truth Length:")
    for length in sorted(len_gt_dist):
        count = len_gt_dist[length]
        print(f"  Length {length}: {count:3d} errors ({count/n_errors*100:5.1f}%)")

    # 5. Sample errors (most informative)
    print("\nSample Errors (showing first 20):")
    print(f"{'':4s}{'Ground Truth':<20s} → {'Predicted':<20s} {'Length'}")
    print(f"    {'-'*20}   {'-'*20} {'-'*8}")

    for i, e in enumerate(errors[:20], 1):
        # Truncate long strings so the columns stay aligned
        gt_display = e['ground_truth'][:18]
        pred_display = e['predicted'][:18]

        if e['len_diff'] < 0:
            len_indicator = f"(-{abs(e['len_diff'])})"
        elif e['len_diff'] > 0:
            len_indicator = f"(+{e['len_diff']})"
        else:
            len_indicator = "(same)"

        print(f"  {i:2d}. {gt_display:<20s} → {pred_display:<20s} {len_indicator}")

    # 6. Key Insights
    print("\nKey Insights:")

    if correct_len / n_errors > 0.5:
        print("  - 50%+ errors have correct length → Focus on character recognition")
    else:
        print("  - <50% errors have correct length → Sequence modeling issue")

    if too_short > too_long * 1.5:
        print("  - Model tends to predict too short → CTC blanks collapsing too much")
    elif too_long > too_short * 1.5:
        print("  - Model tends to predict too long → False positives in character detection")

    # Check for common confusions
    if confusion_counts:
        top_confusion = confusion_counts.most_common(1)[0]
        if top_confusion[1] > n_errors * 0.1:  # More than 10% of errors
            print(f"  - High confusion: '{top_confusion[0][0]}' ↔ '{top_confusion[0][1]}' ({top_confusion[1]} times)")

    print(f"\n{'='*70}")

    return errors


# Run error analysis on the test set and print the diagnostic report
rule = "=" * 70
print("\n" + rule)
print("Running Error Analysis...")
print(rule)

errors = analyze_errors(model, test_loader, device, idx_to_char, num_samples=200)

print("\nError analysis complete!")
print(f"  Next steps will depend on the patterns observed above.")
