In [None]:
# !pip install nltk tqdm

import os
import random
import nltk
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from tqdm import tqdm
from collections import defaultdict
from nltk.corpus import words

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# from google.colab import files
# uploaded = files.upload()



cuda


In [None]:
nltk.download('words')

def generate_image_with_text(text, width=256, height=64):
    image = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("/content/arial.ttf", 24)
    except IOError:
        print("Arial font not found. Using default font.")
        font = ImageFont.load_default()

    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
    position = ((width - text_width) // 2, (height - text_height) // 2)

    draw.text(position, text, fill='black', font=font)
    return np.array(image)

def save_image_dataset(dataset, folder="/content/words/"):
    os.makedirs(folder, exist_ok=True)
    for i, (image_array, word) in enumerate(dataset):
        image = Image.fromarray(image_array)
        filename = f"{folder}/image_{i}_{word}.png"
        image.save(filename)

def generate_and_save():
    word_list = words.words()
    num_samples = 100_000
    dataset = [
        (generate_image_with_text(word), word)
        for word in tqdm(random.choices(word_list, k=num_samples), desc="Generating images", unit="word")
        if word.isalpha()
    ]
    save_image_dataset(dataset)
    print('Images saved to /content/words/')

generate_and_save()


[nltk_data] Downloading package words to /root/nltk_data...
[nltk_data]   Unzipping corpora/words.zip.
Generating images: 100%|██████████| 100000/100000 [02:48<00:00, 593.54word/s]


Images saved to /content/words/


In [11]:
class OCRDataset(Dataset):
    def __init__(self, folder="/content/words/"):
        self.folder = folder
        self.image_files = [f for f in os.listdir(folder) if f.endswith(".png")]
        self.transform = transforms.ToTensor()
        self.encoder = {chr(65 + i): i for i in range(26)} | {chr(97 + i): i + 26 for i in range(26)}

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_filename = self.image_files[idx]
        img_path = os.path.join(self.folder, img_filename)

        image = Image.open(img_path).convert('L')
        image = np.array(image)
        image = np.where(image > 127, 0, 1)
        if self.transform:
            image = self.transform(image)
        image = image.float()

        label = img_filename.split('_', 2)[-1].split('.')[0]
        label_indices = torch.tensor([self.encoder[char] for char in label], dtype=torch.long)
        return image, label_indices


In [12]:
class OCRModel(nn.Module):
    def __init__(self, input_dim=65536, hidden_dim=256, num_classes=53, num_layers=2):
        super(OCRModel, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.flatten = nn.Flatten(start_dim=1)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.rnn = nn.RNN(hidden_dim, hidden_dim, num_layers, batch_first=True, dropout=0.2)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.cnn(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = x.unsqueeze(1).repeat(1, 32, 1)
        rnn_out, _ = self.rnn(x)
        rnn_out = self.layer_norm(rnn_out)
        output = self.fc(rnn_out)
        return output


In [13]:
def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)
    labels = [torch.tensor(label[:32]) if len(label) > 32 else torch.tensor(label) for label in labels]
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=52)
    if labels_padded.size(1) < 32:
        labels_padded = torch.nn.functional.pad(labels_padded, (0, 32 - labels_padded.size(1)), value=52)
    else:
        labels_padded = labels_padded[:, :32]
    return images, labels_padded


In [14]:
def train_ocr_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    optimizer = optim.Adam(model.parameters(), lr=lr)

    decoder = {}
    for i in range(26):
        decoder[i] = chr(65 + i)
    for i in range(26, 52):
        decoder[i] = chr(97 + (i - 26))

    decoder[52] = ''

    class_weights = torch.ones(53)
    class_weights[52] = 0.2
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    model = model.to(device)

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for images, labels in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.permute(0, 2, 1), labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        model.eval()
        correct_chars = 0
        total_chars = 0
        sample_count = 0
        random_correct_chars = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                predictions = torch.argmax(outputs, dim=2)
                random_predictions = labels.clone()
                for i in range(random_predictions.size(0)):
                    for j in range(random_predictions.size(1)):
                        if random_predictions[i, j] != 52:
                            random_predictions[i, j] = random.randint(0, 51)


                if sample_count < 5:
                    for pred, label in zip(predictions, labels):
                        if sample_count >= 5:
                            break
                        decoded_pred = ''.join([decoder[p.item()] for p in pred])
                        decoded_label = ''.join([decoder[l.item()] for l in label])
                        print(f"Predicted: {decoded_pred}", end=', ')
                        print(f"Actual: {decoded_label}")
                        sample_count += 1
                for pred, label in zip(predictions, labels):
                    non_null_mask = label != 52

                    correct_chars += (pred[non_null_mask] == label[non_null_mask]).sum().item()
                    total_chars += non_null_mask.sum().item()

                for random_pred, label in zip(random_predictions, labels):
                    non_null_mask = label != 52
                    random_correct_chars += (random_pred[non_null_mask] == label[non_null_mask]).sum().item()


        avg_correct_chars = correct_chars / total_chars
        avg_train_loss = total_train_loss / len(train_loader)
        random_baseline_acc = random_correct_chars / total_chars

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Validation ANCC: {avg_correct_chars:.4f}, '
              f'Random Baseline Accuracy: {random_baseline_acc:.4f}')

    print("Training complete.")

    if save_path is not None:
        torch.save(model, save_path)

In [15]:
def run_ocr():
    model = OCRModel(num_classes=53)
    dataset = OCRDataset(folder="/content/words/")
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")
    train_ocr_model(model, train_loader, val_loader)

run_ocr()

Train samples: 80000, Validation samples: 20000


  labels = [torch.tensor(label[:32]) if len(label) > 32 else torch.tensor(label) for label in labels]
100%|██████████| 2500/2500 [02:06<00:00, 19.76it/s]


Predicted: poiassism, Actual: potassium
Predicted: unaaalaale, Actual: unwastable
Predicted: camplolize, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midees, Actual: Midwest
Epoch 1/10, Train Loss: 1.5848, Validation ANCC: 0.6248, Random Baseline Accuracy: 0.0196


100%|██████████| 2500/2500 [02:02<00:00, 20.36it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 2/10, Train Loss: 0.5152, Validation ANCC: 0.8681, Random Baseline Accuracy: 0.0189


100%|██████████| 2500/2500 [02:02<00:00, 20.35it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 3/10, Train Loss: 0.2518, Validation ANCC: 0.9148, Random Baseline Accuracy: 0.0192


100%|██████████| 2500/2500 [02:02<00:00, 20.40it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Miwwest, Actual: Midwest
Epoch 4/10, Train Loss: 0.1816, Validation ANCC: 0.9387, Random Baseline Accuracy: 0.0190


100%|██████████| 2500/2500 [02:03<00:00, 20.29it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 5/10, Train Loss: 0.1504, Validation ANCC: 0.9487, Random Baseline Accuracy: 0.0192


100%|██████████| 2500/2500 [02:02<00:00, 20.33it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 6/10, Train Loss: 0.1364, Validation ANCC: 0.9441, Random Baseline Accuracy: 0.0194


100%|██████████| 2500/2500 [02:03<00:00, 20.24it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 7/10, Train Loss: 0.1225, Validation ANCC: 0.9578, Random Baseline Accuracy: 0.0189


100%|██████████| 2500/2500 [02:02<00:00, 20.35it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 8/10, Train Loss: 0.1141, Validation ANCC: 0.9655, Random Baseline Accuracy: 0.0191


100%|██████████| 2500/2500 [02:02<00:00, 20.38it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 9/10, Train Loss: 0.1062, Validation ANCC: 0.9453, Random Baseline Accuracy: 0.0194


100%|██████████| 2500/2500 [02:02<00:00, 20.37it/s]


Predicted: potassium, Actual: potassium
Predicted: unwastable, Actual: unwastable
Predicted: campholide, Actual: campholide
Predicted: Chanabal, Actual: Chanabal
Predicted: Midwest, Actual: Midwest
Epoch 10/10, Train Loss: 0.1018, Validation ANCC: 0.9552, Random Baseline Accuracy: 0.0195
Training complete.
