In [None]:
import nltk
from nltk.corpus import words
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import os

# Load the NLTK English word list (downloading the corpus if needed) and return a shuffled sample of it
def get_wordslist(num_words=100000, seed=None):
    """Return a shuffled sample of words from the NLTK `words` corpus.

    Calls ``nltk.download('words')`` on every invocation so the corpus is fetched
    when it is missing locally.

    Args:
        num_words: maximum number of words to return (default 100000).
        seed: optional int seed for the shuffle. With ``None`` the global NumPy RNG
            is used (the original behaviour); pass an int for a reproducible sample.

    Returns:
        list of str with at most ``num_words`` entries.
    """
    nltk.download('words')

    # Copy so the shuffle never mutates a list object owned by the corpus reader.
    word_list = list(words.words())

    if seed is None:
        np.random.shuffle(word_list)
    else:
        np.random.default_rng(seed).shuffle(word_list)

    return word_list[:num_words]


# Loaded fonts keyed by size, so rendering many words does not re-read the font
# file from disk for every image.
_FONT_CACHE = {}


def _load_font(font_size):
    """Return DejaVuSansMono at `font_size`, or PIL's default font if it cannot be loaded."""
    if font_size not in _FONT_CACHE:
        try:
            _FONT_CACHE[font_size] = ImageFont.truetype("DejaVuSansMono.ttf", font_size)
        except OSError:
            # truetype raises OSError when the font file cannot be found or read.
            _FONT_CACHE[font_size] = ImageFont.load_default()
    return _FONT_CACHE[font_size]


def generate_images(word, width=256, height=64, font_size=36, min_font_size=8):
    """Render `word` in black, centred on a white grayscale canvas.

    Args:
        word: text to draw.
        width, height: canvas size in pixels.
        font_size: preferred font size.
        min_font_size: smallest size the font is shrunk to when `word` does not fit.

    Returns:
        numpy array of the mode-'L' image, shape (height, width).
    """
    image = Image.new('L', (width, height), 255)
    draw = ImageDraw.Draw(image)

    size = font_size
    font = _load_font(size)
    left, top, right, bottom = draw.textbbox((0, 0), word, font=font)
    # Shrink the font until the word fits; otherwise long words are clipped at the
    # canvas edge and the rendered image no longer matches its label.
    while (right - left > width or bottom - top > height) and size > min_font_size:
        size -= 1
        font = _load_font(size)
        left, top, right, bottom = draw.textbbox((0, 0), word, font=font)

    # textbbox drawn at (0, 0) may start at a non-zero offset, so subtract it to
    # centre the text's bounding box rather than the drawing origin.
    x = (width - (right - left)) // 2 - left
    y = (height - (bottom - top)) // 2 - top

    draw.text((x, y), word, font=font, fill=0)

    return np.array(image)


word_list = get_wordslist()

WORD_IMAGES_DIR = '../../data/external/WordImages'
os.makedirs(WORD_IMAGES_DIR, exist_ok=True)

# Each run samples a different word list and never deletes old files, so images
# from earlier runs would silently be picked up by the dataset built later.
expected_files = {f'{word}.png' for word in word_list}
stale_files = [name for name in os.listdir(WORD_IMAGES_DIR) if name not in expected_files]
if stale_files:
    print(f'Warning: {len(stale_files)} files in {WORD_IMAGES_DIR} are not from this run; '
          'delete the directory first for a clean dataset.')

for word in word_list:
    Image.fromarray(generate_images(word)).save(os.path.join(WORD_IMAGES_DIR, f'{word}.png'))

print('Done!')

[nltk_data] Downloading package words to /home/keshava/nltk_data...
[nltk_data]   Package words is already up-to-date!


In [2]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image

class WordImagesDataset(Dataset):
    """Dataset of word images whose label is the file name without its extension.

    Each item is ``(image, encoded_label, label_length)``. ``encoded_label`` is a 1-D
    long tensor of character indices starting at 1; index 0 is never assigned to a
    character, which leaves it free for the CTC blank.

    Args:
        root: directory containing the images.
        transform: optional callable applied to the grayscale PIL image.
        characters: alphabet; a label character outside it makes ``encode_label``
            raise KeyError.
        extensions: lower-case file suffixes treated as images; other entries in
            ``root`` are ignored.
    """

    def __init__(self, root, transform=None, characters='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789- ',
                 extensions=('.png', '.jpg', '.jpeg')):
        self.root = root
        self.transform = transform
        # Sorted so indices (and hence random_split results) do not depend on the
        # filesystem's listing order; stray non-image files are skipped.
        self.images = sorted(
            name for name in os.listdir(root)
            if os.path.splitext(name)[1].lower() in extensions
        )

        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(characters)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}

    def encode_label(self, label):
        """Map each character of `label` to its index (KeyError for unknown characters)."""
        return torch.tensor([self.char_to_idx[char] for char in label], dtype=torch.long)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (image, encoded_label, label_length) for the idx-th file."""
        filename = self.images[idx]
        image_path = os.path.join(self.root, filename)
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(image_path) as img:
            image = img.convert("L")

        if self.transform:
            image = self.transform(image)

        # splitext strips only the final extension, so labels containing '.' survive.
        label = os.path.splitext(filename)[0]
        encoded_label = self.encode_label(label)
        return image, encoded_label, len(encoded_label)

# ToTensor converts the grayscale PIL image to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = WordImagesDataset('../../data/external/WordImages', transform=transform)

# Split 80/10/10 into train/val/test. The seeded generator makes the split
# reproducible across kernel restarts, given a stable dataset file order.
train_size = int(0.8 * len(dataset))
val_size = (len(dataset) - train_size) // 2
test_size = len(dataset) - train_size - val_size

split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Collate function to handle variable-length labels
def collate_fn(batch):
    """Merge (image, label, length) samples into one batch for CTC training.

    Images are stacked along a new batch dimension. Labels are concatenated into a
    single 1-D tensor, and their lengths are returned as a long tensor so the
    per-sample boundaries can be recovered.
    """
    image_seq, label_seq, length_seq = zip(*batch)
    return (
        torch.stack(image_seq),
        torch.cat(label_seq),
        torch.tensor(length_seq, dtype=torch.long),
    )

# One DataLoader per split; only the training loader reshuffles every epoch.
# collate_fn is needed because label sequences have different lengths.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=32, shuffle=is_train, collate_fn=collate_fn)
    for subset, is_train in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

for split_name, subset in (('Training', train_dataset), ('Validation', val_dataset), ('Test', test_dataset)):
    print(f'{split_name} set size: {len(subset)}')

Training set size: 82128
Validation set size: 10266
Test set size: 10266


In [None]:
# Define the CRNN model
class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column class scores for CTC.

    Two conv/pool stages downsample the input by 4 in each spatial dimension. Each
    remaining image column (channels x height features) becomes one LSTM time step,
    so the output sequence runs left to right across the image.

    Args:
        num_features: size of the projection fed to the LSTM.
        hidden_size: LSTM hidden size per direction.
        vocab_size: number of output classes (including the CTC blank).
        dropout_rate: dropout applied to the LSTM outputs.
        img_height: input image height; fixes the per-column feature size.

    Input:  (batch, 1, img_height, width) tensor.
    Output: (batch, width // 4, vocab_size) unnormalised scores.
    """

    def __init__(self, num_features=256, hidden_size=128, vocab_size=54, dropout_rate=0.3, img_height=64):
        super(CRNN, self).__init__()
        cnn_channels = 64
        self.CNN = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, cnn_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(cnn_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Two 2x2 max-pools reduce the height to img_height // 4.
        self.fc1 = nn.Linear(cnn_channels * (img_height // 4), num_features)
        # LSTM's own `dropout` only acts between stacked layers (no-op for one layer),
        # so dropout is applied explicitly to the LSTM output instead.
        self.lstm = nn.LSTM(num_features, hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size * 2, vocab_size)

    def forward(self, x):
        x = self.CNN(x)  # (B, C, H', W')
        batch_size, channels, height, width = x.shape
        # Use image width as the time axis: (B, W', C * H').
        x = x.permute(0, 3, 1, 2).reshape(batch_size, width, channels * height)
        x = self.fc1(x)
        x, _ = self.lstm(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Average Number of Correct Characters
def calculate_correct_characters(preds, targets):
    correct = 0
    for pred, target in zip(preds, targets):
        pred = [p for p in pred if p != 0]
        target = [t.item() for t in target if t != 0]
        correct += sum([1 if p == t else 0 for p, t in zip(pred, target)])
    return correct


# Training loop
def _ctc_forward(model, ctc_loss, images, targets, target_lengths, device):
    """Run `model` on one batch; return (log_probs shaped (T, N, C), CTC loss)."""
    images = images.to(device)
    targets = targets.to(device)
    target_lengths = target_lengths.to(device)
    # CTCLoss expects time-major (T, N, C) log-probabilities.
    log_probs = model(images).log_softmax(2).permute(1, 0, 2)
    # Every sample uses the full output sequence, so all input lengths equal T.
    input_lengths = torch.full(
        size=(log_probs.size(1),),
        fill_value=log_probs.size(0),
        dtype=torch.long,
        device=device,
    )
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    return log_probs, loss


def train_ctc(model, train_loader, val_loader, optimizer, num_epochs=10):
    """Train `model` with CTC loss and report validation metrics after each epoch.

    Args:
        model: module mapping an image batch to (batch, T, num_classes) scores,
            with class 0 used as the CTC blank.
        train_loader, val_loader: yield (images, concatenated_targets, target_lengths),
            as produced by collate_fn.
        optimizer: optimizer over model.parameters().
        num_epochs: number of passes over train_loader.

    Returns:
        list with one dict per epoch: 'train_loss', 'val_loss', 'char_accuracy'.
    """
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    history = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, targets, target_lengths in train_loader:
            _, loss = _ctc_forward(model, ctc_loss, images, targets, target_lengths, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= max(len(train_loader), 1)

        # Validation phase
        model.eval()
        val_loss = 0.0
        total_correct_characters = 0
        total_characters = 0

        with torch.no_grad():
            for images, targets, target_lengths in val_loader:
                log_probs, loss = _ctc_forward(model, ctc_loss, images, targets, target_lengths, device)
                val_loss += loss.item()

                # Frame-wise argmax, reshaped to (N, T) for per-sample decoding.
                predicted_indices = log_probs.argmax(dim=2).permute(1, 0).cpu()
                # collate_fn concatenates all labels into one 1-D tensor; split it back
                # into per-sample sequences before comparing with predictions.
                per_sample_targets = torch.split(targets.cpu(), target_lengths.tolist())

                total_correct_characters += calculate_correct_characters(predicted_indices, per_sample_targets)
                total_characters += target_lengths.sum().item()

        val_loss /= max(len(val_loader), 1)
        avg_correct_characters = total_correct_characters / total_characters if total_characters > 0 else 0

        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Average Correct Characters: {avg_correct_characters:.4f}')
        history.append({
            'train_loss': train_loss,
            'val_loss': val_loss,
            'char_accuracy': avg_correct_characters,
        })

    return history

        
# Instantiate model and start training.
# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(0)

# +1 because index 0 is reserved for the CTC blank.
num_classes = len(dataset.char_to_idx) + 1
model = CRNN(vocab_size=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Assigned so the notebook does not echo the return value.
history = train_ctc(model, train_loader, val_loader, optimizer, num_epochs=10)
