In [14]:
import os
import random
import nltk
from PIL import Image, ImageDraw, ImageFont

# Ensure nltk resources are available
nltk.download('words')
from nltk.corpus import words

# Parameters
output_dir = "ocr_dataset"
image_size = (256, 64)
font_size = 32
num_samples = 100000

# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# Fetch word list and sample 100k words
word_list = words.words()
selected_words = random.sample(word_list, num_samples)

# Load a default font
try:
    font = ImageFont.truetype("arial.ttf", font_size)
except IOError:
    font = ImageFont.load_default()

# Generate images
for word in selected_words:
    # Sanitize word to ensure valid filenames
    sanitized_word = word.replace("/", "").replace("\\", "")

    # Create a blank image with white background
    img = Image.new("RGB", image_size, "white")
    draw = ImageDraw.Draw(img)

    # Get text size and position
    text_bbox = draw.textbbox((0, 0), sanitized_word, font=font)  # Bounding box for the text
    text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
    position = ((image_size[0] - text_width) // 2, (image_size[1] - text_height) // 2)

    # Render word onto image
    draw.text(position, sanitized_word, fill="black", font=font)

    # Save image
    filename = f"{sanitized_word}.png"
    img.save(os.path.join(output_dir, filename))

print(f"Dataset generation complete. Images saved to {output_dir}")

[nltk_data] Downloading package words to /root/nltk_data...
[nltk_data]   Package words is already up-to-date!


Dataset generation complete. Images saved to ocr_dataset


In [15]:
import os
import shutil
import random

# Paths
ocr_dataset_dir = "ocr_dataset"
train_dir = os.path.join(ocr_dataset_dir, "train")
val_dir = os.path.join(ocr_dataset_dir, "val")
test_dir = os.path.join(ocr_dataset_dir, "test")

# Clean existing split directories if they exist
for dir_path in [train_dir, val_dir, test_dir]:
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)

# Get list of all image files
image_filenames = [f for f in os.listdir(ocr_dataset_dir) if f.endswith(".png")]

# Shuffle the filenames randomly
random.shuffle(image_filenames)

# Calculate split indices
total_files = len(image_filenames)
train_split = int(0.8 * total_files)
val_split = int(0.9 * total_files)

# Split files
train_files = image_filenames[:train_split]
val_files = image_filenames[train_split:val_split]
test_files = image_filenames[val_split:]

# Move files to respective folders
for files, dest_dir in zip([train_files, val_files, test_files], [train_dir, val_dir, test_dir]):
    for file in files:
        src_path = os.path.join(ocr_dataset_dir, file)
        if os.path.exists(src_path):  # Check if file exists before moving
            shutil.copy2(src_path, os.path.join(dest_dir, file))  # Using copy2 instead of move

# Print statistics
print(f"Total files: {total_files}")
print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")
print(f"Test files: {len(test_files)}")

# Print sample of files in each split to verify distribution
print("\nSample of training files:", sorted(os.listdir(train_dir))[:5])
print("Sample of validation files:", sorted(os.listdir(val_dir))[:5])
print("Sample of test files:", sorted(os.listdir(test_dir))[:5])

Total files: 99846
Training files: 79876
Validation files: 9985
Test files: 9985

Sample of training files: ['A.png', 'Aani.png', 'Aaronical.png', 'Aaronite.png', 'Aaru.png']
Sample of validation files: ['Abadite.png', 'Abassin.png', 'Abrus.png', 'Absi.png', 'Abuta.png']
Sample of test files: ['Ababdeh.png', 'Abelicea.png', 'Abipon.png', 'Acanthodidae.png', 'Aceraceae.png']


In [19]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms

In [18]:
class OCRDataset(Dataset):
    def __init__(self, image_dir, char_to_idx, transform=None, max_length=30):
        self.image_dir = image_dir
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(".png")]
        self.char_to_idx = char_to_idx
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_name = self.image_files[idx]
        # Extract label from filename (remove .png extension)
        label = os.path.splitext(image_name)[0]

        # Encode the label using char_to_idx
        label_encoded = [self.char_to_idx[char] for char in label]

        # Load image
        image_path = os.path.join(self.image_dir, image_name)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Convert label to tensor with padding
        label_tensor = torch.full((self.max_length,), -1, dtype=torch.long)  # Fill with padding token (-1)
        label_tensor[:len(label_encoded)] = torch.tensor(label_encoded)

        return image, label_tensor

# Complete character set including upper and lowercase letters, numbers, and common special characters
char_set = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 '-"
char_to_idx = {char: idx for idx, char in enumerate(char_set)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}


In [22]:
# Create datasets with transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# Create datasets
train_dataset = OCRDataset("ocr_dataset/train", char_to_idx, transform=transform, max_length=30)
val_dataset = OCRDataset("ocr_dataset/val", char_to_idx, transform=transform, max_length=30)
test_dataset = OCRDataset("ocr_dataset/test", char_to_idx, transform=transform, max_length=30)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)





In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys

In [3]:
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../models/RNN')))
from crnn import OCRModel

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class OCRModel(nn.Module):
    def __init__(self, num_classes, hidden_dim=256):
        super(OCRModel, self).__init__()
        # CNN Encoder
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # RNN Decoder
        self.rnn = nn.GRU(
            input_size=64 * (256 // 4),  # Feature vector size from CNN
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
        )

        # Output layer
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # Pass through CNN
        batch_size = x.size(0)
        features = self.cnn(x)  # Output shape: [B, 64, H/4, W/4]
        features = features.permute(0, 2, 3, 1)  # Shape: [B, H/4, W/4, 64]
        features = features.reshape(batch_size, -1, 64 * (256 // 4))  # Shape: [B, Seq_len, Feature_dim]

        # Pass through RNN
        rnn_out, _ = self.rnn(features)  # Shape: [B, Seq_len, Hidden_dim]
        output = self.fc(rnn_out)  # Shape: [B, Seq_len, Num_classes]

        return output

In [23]:
# Model, loss, optimizer
# Update model to match new vocabulary size
model = OCRModel(num_classes=len(char_to_idx))
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # Ignore padding tokens if applicable
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

OCRModel(
  (cnn): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (rnn): GRU(4096, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=256, out_features=65, bias=True)
)

In [27]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)  # [batch_size, channels, height, width]
        labels = labels.to(device)  # [batch_size, max_length]

        # Forward pass
        outputs = model(images)  # [batch_size, seq_len, num_classes]

        # Get the actual batch size and sequence length
        batch_size, seq_len, num_classes = outputs.shape

        # Ensure labels match the sequence length of outputs
        labels = labels[:, :seq_len]

        # Calculate loss (keep batch dimension)
        loss = 0
        for i in range(seq_len):
            # Get predictions and labels for current position
            curr_output = outputs[:, i, :]  # [batch_size, num_classes]
            curr_labels = labels[:, i]      # [batch_size]

            # Only calculate loss for non-padding positions
            mask = curr_labels != -1
            if mask.any():
                curr_loss = criterion(curr_output[mask], curr_labels[mask])
                loss += curr_loss

        loss = loss / seq_len  # Average over sequence length

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct_chars = 0
    total_chars = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)  # [batch_size, seq_len, num_classes]

            # Get the actual batch size and sequence length
            batch_size, seq_len, num_classes = outputs.shape

            # Ensure labels match the sequence length of outputs
            labels = labels[:, :seq_len]

            # Calculate accuracy and loss
            for i in range(seq_len):
                curr_output = outputs[:, i, :]  # [batch_size, num_classes]
                curr_labels = labels[:, i]      # [batch_size]

                # Only consider non-padding positions
                mask = curr_labels != -1
                if mask.any():
                    curr_loss = criterion(curr_output[mask], curr_labels[mask])
                    total_loss += curr_loss.item()

                    predictions = curr_output[mask].argmax(dim=1)
                    correct_chars += (predictions == curr_labels[mask]).sum().item()
                    total_chars += mask.sum().item()

    avg_loss = total_loss / (len(val_loader) * seq_len)
    accuracy = correct_chars / total_chars if total_chars > 0 else 0
    return avg_loss, accuracy

def decode_prediction(prediction, idx_to_char):
    """Convert a prediction tensor to a string"""
    return ''.join(idx_to_char[idx.item()] for idx in prediction if idx != -1 and idx.item() < len(idx_to_char))

# Training loop
print("Starting training...")
for epoch in range(num_epochs):
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, accuracy = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Character Accuracy: {accuracy:.4f}")

    # Print examples every few epochs
    if epoch % 2 == 0:
        model.eval()
        with torch.no_grad():
            images, labels = next(iter(val_loader))
            images = images.to(device)
            outputs = model(images)
            predictions = outputs.argmax(dim=2)

            # Print a few examples
            for i in range(min(3, len(images))):
                true_text = decode_prediction(labels[i], idx_to_char)
                pred_text = decode_prediction(predictions[i].cpu(), idx_to_char)
                print(f"\nExample {i+1}:")
                print(f"True: {true_text}")
                print(f"Pred: {pred_text}")

    print("-" * 50)

Starting training...
Epoch 1/10
Train Loss: 2.6928
Val Loss: 2.5491
Character Accuracy: 0.1288

Example 1:
True: cyanopia
Pred: serriiiaieeyssss

Example 2:
True: multipresent
Pred: serriiiiiiiaannn

Example 3:
True: Epitoniidae
Pred: serriiiaeeesssss
--------------------------------------------------
Epoch 2/10
Train Loss: 2.3426
Val Loss: 2.1392
Character Accuracy: 0.2005
--------------------------------------------------
Epoch 3/10
Train Loss: 1.9968
Val Loss: 1.9147
Character Accuracy: 0.2432

Example 1:
True: cyanopia
Pred: periioiaatyyyyyy

Example 2:
True: multipresent
Pred: periiooeeeetttts

Example 3:
True: Epitoniidae
Pred: periioiidaeessss
--------------------------------------------------
Epoch 4/10
Train Loss: 1.8036
Val Loss: 1.7283
Character Accuracy: 0.2755
--------------------------------------------------
Epoch 5/10
Train Loss: 1.6865
Val Loss: 1.6887
Character Accuracy: 0.2889

Example 1:
True: cyanopia
Pred: pereietatyysssss

Example 2:
True: multipresent
Pred: pere

In [28]:
import torch
import numpy as np
from tqdm import tqdm

def evaluate_model(model, test_loader, idx_to_char, device, num_examples=5):
    model.eval()
    total_chars = 0
    correct_chars = 0
    total_words = 0
    correct_words = 0
    examples = []

    print("\nEvaluating trained model...")
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)  # [batch_size, seq_len, num_classes]
            batch_size, seq_len, num_classes = outputs.shape

            # Ensure labels match the sequence length of outputs
            labels = labels[:, :seq_len]

            # Get predictions
            predictions = outputs.argmax(dim=2)  # [batch_size, seq_len]

            # Calculate accuracy
            for i in range(seq_len):
                curr_preds = predictions[:, i]
                curr_labels = labels[:, i]

                # Only consider non-padding positions
                mask = curr_labels != -1
                if mask.any():
                    correct_chars += (curr_preds[mask] == curr_labels[mask]).sum().item()
                    total_chars += mask.sum().item()

            # Collect examples
            if len(examples) < num_examples:
                for i in range(len(images)):
                    true_text = decode_prediction(labels[i], idx_to_char)
                    pred_text = decode_prediction(predictions[i].cpu(), idx_to_char)

                    # Calculate per-example accuracy
                    example_correct = 0
                    example_total = 0
                    for j in range(len(true_text)):
                        if j < len(pred_text):
                            if true_text[j] == pred_text[j]:
                                example_correct += 1
                        example_total += 1

                    example_accuracy = example_correct / example_total if example_total > 0 else 0
                    examples.append((true_text, pred_text, example_accuracy))

                    # Word accuracy
                    if true_text == pred_text:
                        correct_words += 1
                    total_words += 1

    # Calculate overall metrics
    char_accuracy = (correct_chars / total_chars * 100) if total_chars > 0 else 0
    word_accuracy = (correct_words / total_words * 100) if total_words > 0 else 0

    print("\nTrained Model Results:")
    print(f"Character Accuracy: {char_accuracy:.2f}%")
    print(f"Word Accuracy (Exact Match): {word_accuracy:.2f}%")

    print("\nExample Predictions:")
    for i, (true_text, pred_text, char_acc) in enumerate(examples, 1):
        print(f"\nExample {i}:")
        print(f"True: {true_text}")
        print(f"Pred: {pred_text}")
        print(f"Character Accuracy: {char_acc*100:.2f}%")

    return char_accuracy, word_accuracy

class BaselinePredictor:
    def __init__(self, train_loader):
        """Initialize baseline predictor with training data statistics"""
        print("\nAnalyzing training data for baseline predictions...")
        self.char_freqs = {}  # Character frequency at each position
        self.avg_length = 0
        self.total_samples = 0

        # Collect statistics from training data
        for _, labels in tqdm(train_loader):
            self.total_samples += len(labels)
            for label_tensor in labels:
                text = decode_prediction(label_tensor, idx_to_char)
                self.avg_length += len(text)

                # Count character frequencies at each position
                for pos, char in enumerate(text):
                    if pos not in self.char_freqs:
                        self.char_freqs[pos] = {}
                    self.char_freqs[pos][char] = self.char_freqs[pos].get(char, 0) + 1

        self.avg_length = round(self.avg_length / self.total_samples)

        # Convert frequencies to most common characters
        self.most_common_chars = {}
        for pos in self.char_freqs:
            self.most_common_chars[pos] = max(self.char_freqs[pos].items(), key=lambda x: x[1])[0]

    def predict(self, batch_size):
        """Generate predictions for a batch"""
        predictions = []
        for _ in range(batch_size):
            pred = ''
            for pos in range(min(self.avg_length, max(self.char_freqs.keys()) + 1)):
                pred += self.most_common_chars.get(pos, 'a')  # Default to 'a' if position not seen
            predictions.append(pred)
        return predictions

def evaluate_baseline(baseline_predictor, test_loader, num_examples=5):
    print("\nEvaluating baseline model...")
    total_chars = 0
    correct_chars = 0
    total_words = 0
    correct_words = 0
    examples = []

    for images, labels in tqdm(test_loader):
        batch_predictions = baseline_predictor.predict(len(images))

        for i in range(len(images)):
            true_text = decode_prediction(labels[i], idx_to_char)
            pred_text = batch_predictions[i]

            # Calculate character accuracy
            min_len = min(len(true_text), len(pred_text))
            max_len = max(len(true_text), len(pred_text))

            correct = sum(1 for j in range(min_len) if true_text[j] == pred_text[j])
            correct_chars += correct
            total_chars += max_len

            # Word accuracy
            if true_text == pred_text:
                correct_words += 1
            total_words += 1

            # Store example
            if len(examples) < num_examples:
                example_accuracy = correct / max_len if max_len > 0 else 0
                examples.append((true_text, pred_text, example_accuracy))

    # Calculate overall metrics
    char_accuracy = (correct_chars / total_chars * 100) if total_chars > 0 else 0
    word_accuracy = (correct_words / total_words * 100) if total_words > 0 else 0

    print("\nBaseline Model Results:")
    print(f"Character Accuracy: {char_accuracy:.2f}%")
    print(f"Word Accuracy (Exact Match): {word_accuracy:.2f}%")

    print("\nExample Predictions:")
    for i, (true_text, pred_text, char_acc) in enumerate(examples, 1):
        print(f"\nExample {i}:")
        print(f"True: {true_text}")
        print(f"Pred: {pred_text}")
        print(f"Character Accuracy: {char_acc*100:.2f}%")

    return char_accuracy, word_accuracy

# Run evaluations
print("Starting final evaluation...")

# Evaluate trained model
model_char_acc, model_word_acc = evaluate_model(
    model, test_loader, idx_to_char, device, num_examples=5
)

# Create and evaluate baseline
baseline_predictor = BaselinePredictor(train_loader)
baseline_char_acc, baseline_word_acc = evaluate_baseline(
    baseline_predictor, test_loader, num_examples=5
)

# Print comparison
print("\nFinal Comparison:")
print("-" * 50)
print(f"{'Metric':<25} {'Trained Model':>15} {'Baseline':>15}")
print("-" * 50)
print(f"{'Character Accuracy':<25} {model_char_acc:>14.2f}% {baseline_char_acc:>14.2f}%")
print(f"{'Word Accuracy':<25} {model_word_acc:>14.2f}% {baseline_word_acc:>14.2f}%")

Starting final evaluation...

Evaluating trained model...


100%|██████████| 313/313 [00:13<00:00, 22.87it/s]



Trained Model Results:
Character Accuracy: 33.03%
Word Accuracy (Exact Match): 0.00%

Example Predictions:

Example 1:
True: superabundance
Pred: sereeedandaneess
Character Accuracy: 50.00%

Example 2:
True: presbytic
Pred: sereeeticsssssss
Character Accuracy: 33.33%

Example 3:
True: sewn
Pred: sereeeeeyyssyyyw
Character Accuracy: 50.00%

Example 4:
True: acoelomatous
Pred: sereeeiatousssss
Character Accuracy: 50.00%

Example 5:
True: afforestable
Pred: sereeeaaableesss
Character Accuracy: 41.67%

Example 6:
True: greener
Pred: sereeeeyyyyyQJqq
Character Accuracy: 28.57%

Example 7:
True: aleph
Pred: sereeeeyyysy6qqq
Character Accuracy: 0.00%

Example 8:
True: Saturnicentric
Pred: sereeeteentticcc
Character Accuracy: 35.71%

Example 9:
True: condescendingnes
Pred: sereeeldddignnes
Character Accuracy: 37.50%

Example 10:
True: unsubstituted
Pred: sereeetttttedddr
Character Accuracy: 38.46%

Example 11:
True: antimonarchial
Pred: sereeeorcchially
Character Accuracy: 35.71%

Example 12:

100%|██████████| 2497/2497 [01:31<00:00, 27.20it/s]



Evaluating baseline model...


100%|██████████| 313/313 [00:11<00:00, 26.97it/s]


Baseline Model Results:
Character Accuracy: 8.81%
Word Accuracy (Exact Match): 0.00%

Example Predictions:

Example 1:
True: superabundance
Pred: sereeiieee
Character Accuracy: 14.29%

Example 2:
True: presbytic
Pred: sereeiieee
Character Accuracy: 0.00%

Example 3:
True: sewn
Pred: sereeiieee
Character Accuracy: 20.00%

Example 4:
True: acoelomatous
Pred: sereeiieee
Character Accuracy: 8.33%

Example 5:
True: afforestable
Pred: sereeiieee
Character Accuracy: 0.00%

Final Comparison:
--------------------------------------------------
Metric                      Trained Model        Baseline
--------------------------------------------------
Character Accuracy                 33.03%           8.81%
Word Accuracy                       0.00%           0.00%



