In [12]:
import os
import shutil
import torch
import torch.nn.functional as F
import random
from torch import nn
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

In [13]:
with open('data/raw/ittk5/alphabet.txt') as f:
    alphabet = f.readline()
print(alphabet)    

# Map the characters in the alphabet to the index
alphabet_map = {}
for i, char in enumerate(alphabet):
    # The index of blank in CTCLoss should be zero.
    # The first one in the alphabet has been left blank,
    # and there is no need for special operation here
    alphabet_map[char] = i
print(alphabet_map)

 0123456789.-+ABCDEFGHIJKLMNOPQRSTUVWXYZ/\abcdefghijklmnopqrstuvwxyz,!@#$%^&*()?:;'"~`
{' ': 0, '0': 1, '1': 2, '2': 3, '3': 4, '4': 5, '5': 6, '6': 7, '7': 8, '8': 9, '9': 10, '.': 11, '-': 12, '+': 13, 'A': 14, 'B': 15, 'C': 16, 'D': 17, 'E': 18, 'F': 19, 'G': 20, 'H': 21, 'I': 22, 'J': 23, 'K': 24, 'L': 25, 'M': 26, 'N': 27, 'O': 28, 'P': 29, 'Q': 30, 'R': 31, 'S': 32, 'T': 33, 'U': 34, 'V': 35, 'W': 36, 'X': 37, 'Y': 38, 'Z': 39, '/': 40, '\\': 41, 'a': 42, 'b': 43, 'c': 44, 'd': 45, 'e': 46, 'f': 47, 'g': 48, 'h': 49, 'i': 50, 'j': 51, 'k': 52, 'l': 53, 'm': 54, 'n': 55, 'o': 56, 'p': 57, 'q': 58, 'r': 59, 's': 60, 't': 61, 'u': 62, 'v': 63, 'w': 64, 'x': 65, 'y': 66, 'z': 67, ',': 68, '!': 69, '@': 70, '#': 71, '$': 72, '%': 73, '^': 74, '&': 75, '*': 76, '(': 77, ')': 78, '?': 79, ':': 80, ';': 81, "'": 82, '"': 83, '~': 84, '`': 85}


In [14]:
import os
import shutil

def rename_and_save_images(mode):
    if mode == 'train':
        img_dir = 'data/raw/ittk5/train/img'
        label_dir = 'data/raw/ittk5/train/label'
        dest_dir = 'data/processed/ittk5/train/'
    elif mode == 'test':
        img_dir = 'data/raw/ittk5/test/img'
        label_dir = 'data/raw/ittk5/test/label'
        dest_dir = 'data/processed/ittk5/test/'

    os.makedirs(dest_dir, exist_ok=True)

    for img_name in os.listdir(img_dir):
        if img_name.endswith('.jpg'):
            base_name = os.path.splitext(img_name)[0]
            label_file = os.path.join(label_dir, f"{base_name}.txt")
            if os.path.exists(label_file):
                try:
                    with open(label_file, 'r') as f:
                        label = f.read().strip()
                    # Sanitize the label to remove slashes and other invalid characters
                    sanitized_label = label.replace('/', '_')
                    new_img_name = f"{base_name}_{sanitized_label}.jpg"
                    shutil.copy(
                        os.path.join(img_dir, img_name),
                        os.path.join(dest_dir, new_img_name)
                    )
                except Exception as e:
                    print(f"Skipping {img_name}: {e}")
                    continue

rename_and_save_images(mode='train')
rename_and_save_images(mode='test')

In [15]:
class MyDataset(Dataset):
    """Create dataset inherited from torch.utils.data.Dataset
    
    Attributes:
        data_dir: train dir or test dir.
        alphabet_map: The map from char to index.
        img_names: File names of all image under the data_dir.
        lables: Labels of all image under the data_dir.
        trans: Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]
        to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    
    """
    def __init__(self, data_dir):
        """Inits dataset"""
        self.data_dir = data_dir
        self.alphabet_map = alphabet_map
        self.img_names = os.listdir(self.data_dir)
        self.labels = [i.split('_')[1].split('.')[0] for i in self.img_names]
        #print(f'label: {self.labels}')
        self.trans = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((32, 128)),
            transforms.RandomRotation(5),
            transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
            transforms.ToTensor()
        ])
        
    def __getitem__(self, idx):
        """Get single image by idx
        
        Args:
            idx: index
            
        Returns:
            img: torch.FloatTensor
            label: Actual lable of the image, like "ZOW-PRF-LFB".
        """
        img_path = os.path.join(self.data_dir, self.img_names[idx])
        img = Image.open(img_path)
        img = self.trans(img)
        label = self.labels[idx]
        #print(f'label: {label}')
        return img, label
        
    def __len__(self):
        return len(self.labels)

    
class BiLSTM(nn.Module):
    """ Bidirectional LSTM and embedding layer.
    
    Attributes:
        rnn: Bidirectional LSTM
        linear: Embedding layer
    """
    def __init__(self, num_input, num_hiddens, num_output):
        super().__init__()
        self.rnn = nn.LSTM(num_input, num_hiddens, bidirectional=True)
        # the size of input of embedding layer should mutiply by 2, because of the bidirectional.
        self.linear = nn.Linear(num_hiddens * 2, num_output)  
    
    def forward(self, X):
        rnn_out, _ = self.rnn(X)
        T, b, h = rnn_out.size()  # T: time step, b: batch size, h: hidden size
        t_rec = rnn_out.view(T * b, h)
        output = self.linear(t_rec)
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):
    def __init__(self, num_class):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),  # Height: 32 -> 16
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),  # Height: 16 -> 8
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),  # Height: 8 -> 4
            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),  # Height: 4 -> 2
            nn.Conv2d(512, 512, kernel_size=(2, 2), stride=(1, 1), padding=(0, 0)),  # Height: 2 -> 1
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

        self.rnn = nn.Sequential(
            BiLSTM(512, 256, 256),
            BiLSTM(256, 256, num_class)
        )
        
    def forward(self, X):
        cnn_out = self.cnn(X)  # cnn_out shape: (batch_size x channel x height x width)
        assert cnn_out.shape[2] == 1, "the height of conv must be 1"
        cnn_out = cnn_out.squeeze(2)  # squeeze the dim 2 (height) of cnn_out
        cnn_out = cnn_out.permute(2, 0, 1)  # move the width to the first dim, as the time step of rnn input
        output = self.rnn(cnn_out)  # output shape: (time step x batch_size x num_class)
        output = F.log_softmax(output, dim=2)  # do softmax at the dim of num_class
        return output
        

In [16]:
train_set = MyDataset(data_dir='data/processed/ittk5/train')
batch_size = 64
trainloader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, drop_last=True)

# Check if the input and output shapes meet expectations
for X, y in trainloader:
    break
print('input shape:', X.shape)
crnn = CRNN(num_class=len(alphabet))
#print(crnn)
preds = crnn(X)
print('output shape from CRNNnet:', preds.shape)

input shape: torch.Size([64, 1, 32, 128])
output shape from CRNNnet: torch.Size([33, 64, 86])


In [17]:
def get_ctcloss_parameters(text_batch):
    """Convert the real text batch into three parameters required by ctcloss,
    encoded text/predict length/real length
    
    Args:
        text_batch: real text batch, like('E-Z-4', 'EMD-6-04')
        
    Returns:
        encoded_text: encode text by alphabet_map 
        preds_length: (time step x batch_size) => (51 * batch_size)
        actual_length: length of text to index，max(len(text)) * batch_size
    """
    actual_length = []
    result = []
    for item in text_batch:            
        actual_length.append(len(item))
        r = []
        for char in item:
            index = alphabet_map[char]
            r.append(index)
        result.append(r)

    max_len = 0
    for r in result:
        if len(r) > max_len:
            max_len = len(r)

    result_temp = []
    for r in result:
        for i in range(max_len - len(r)):
            r.append(0)
        result_temp.append(r)

    encoded_text = result_temp
    encoded_text = torch.LongTensor(encoded_text)
    preds_length = torch.LongTensor([preds.size(0)] * batch_size)
    actual_length = torch.LongTensor(actual_length)
    return encoded_text, preds_length, actual_length

In [18]:
use_gpu = True
num_epoch = 100

if use_gpu:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
crnn.train()
trainer = torch.optim.Adam(crnn.parameters(), lr=0.001)
loss = nn.CTCLoss(zero_infinity=True)
crnn = crnn.to(device)
loss = loss.to(device)
    
for epoch in range(num_epoch):
    for X, y in trainloader:
        X = X.to(device)
        trainer.zero_grad()
        preds = crnn(X) 
        encoded_text, preds_length, actual_length = get_ctcloss_parameters(y)

        encoded_text = encoded_text.to(device)
        preds_length = preds_length.to(device)
        actual_length = actual_length.to(device)
        
        l = loss(preds, encoded_text,preds_length, actual_length) / batch_size
        l.backward()
        trainer.step()
    print('epoch', str(epoch + 1).ljust(10), 'loss:', format(l.item(), '.6f'))

epoch 1          loss: 0.069445
epoch 2          loss: 0.067837
epoch 3          loss: 0.068413
epoch 4          loss: 0.067885
epoch 5          loss: 0.066891
epoch 6          loss: 0.065355
epoch 7          loss: 0.065597
epoch 8          loss: 0.065623
epoch 9          loss: 0.065407
epoch 10         loss: 0.063945
epoch 11         loss: 0.063161
epoch 12         loss: 0.062921
epoch 13         loss: 0.062178
epoch 14         loss: 0.062293
epoch 15         loss: 0.057206
epoch 16         loss: 0.058468
epoch 17         loss: 0.056464
epoch 18         loss: 0.059342
epoch 19         loss: 0.057968
epoch 20         loss: 0.057134
epoch 21         loss: 0.056437
epoch 22         loss: 0.056523
epoch 23         loss: 0.049896
epoch 24         loss: 0.051441
epoch 25         loss: 0.048132
epoch 26         loss: 0.047355
epoch 27         loss: 0.046204
epoch 28         loss: 0.044522
epoch 29         loss: 0.037848
epoch 30         loss: 0.038982
epoch 31         loss: 0.034903
epoch 32

In [19]:
def get_final_pred(text):
    """Remove adjacent duplicate characters

    Args:
        text: Do argmax after crnn net ouput
        
    Returns:
        final_text: Text removed adjacent duplicate characters
    """
    text = list(text)
    for i in range(len(text)):
        for j in range(i + 1, len(text)):
            if text[j] == ' ':
                break
            else:
                if text[j] == text[i]:
                    text[j] = ' '
                else:
                    continue
    final_text = ''.join(text).replace(' ', '')
    return final_text

def predict(net, X, y):
    """Predict batch images, print predict result and ground truth.
    
    Args:
        net: crnn net
        X: batch images
        y: batch actual texts
    """
    preds = net(X)
    _, preds = preds.max(2)
    idx = 0
    print('crnn net output'.ljust(51), '|', 'final predict'.ljust(20), '|', 'ground truth'.ljust(20))
    print('=' * 99)
    for pred in preds.permute(1, 0):
        pred_text = ''.join([alphabet[i.item()] for i in pred])
        print(pred_text, '|', get_final_pred(pred_text).ljust(20), '|', y[idx].ljust(20))
        print('·' * 99)
        idx += 1

In [31]:
test_set = MyDataset(data_dir='data/processed/ittk5/test')

# predict single image with random index
idx = random.randint(0, len(test_set) - 1)
X, y = test_set[idx]
X = X.unsqueeze(0) # add dim as batch
y = [y]
X = X.to(device)
predict(crnn, X, y)
print('\n' * 2)
# predict batch using dataloader
testloader = DataLoader(test_set, batch_size=8, shuffle=True, drop_last=True)
X, y = next(iter(testloader))
X = X.to(device)
predict(crnn, X, y)

crnn net output                                     | final predict        | ground truth        
Gaaaaaaaa9aaaaQaaaaaa9aaaaUaaaaaa | Ga9Q9U               | State               
···································································································



crnn net output                                     | final predict        | ground truth        
KaaaaaaaaaCaaaaaFaaaaazaaaraaaaaa | KaCFzr               | WORLD               
···································································································
gaaaaaakaaaaaabaaaaaaaiaaaaajaaaa | gakbij               | SOUTH               
···································································································
caaaaaaaaCaaaaaaagaaaaaaaagaaaaaa | caCgg                | JOE'S               
···································································································
qaaaaaaaaaaQaaaaaaaa7aaaaaUaaaaaa | qaQ7U                | Cafe                
·

In [27]:
from PIL import Image
import torch
from torchvision import transforms

# Define the transformation
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((32, 128)),  # Adjust height to 32 or as needed
    transforms.ToTensor()
])

def load_image(image_path):
    # Load the image
    image = Image.open(image_path)
    # Apply the transformations
    image = transform(image)
    # Add a batch dimension
    image = image.unsqueeze(0)
    return image

def decode(output, alphabet):
    # Assuming output is a tensor of shape (time_step, batch_size, num_class)
    output = output.permute(1, 0, 2)  # Change to (batch_size, time_step, num_class)
    output = output.squeeze(0)  # Remove batch dimension
    _, max_indices = torch.max(output, dim=1)
    
    print(f'Max indices: {max_indices}')  # Debugging: print the max indices
    
    # Convert indices to characters and handle repeated characters
    predicted_text = []
    prev_idx = None
    for idx in max_indices:
        if idx != prev_idx and idx != 0:  # Skip repeated characters and blank character (assuming blank is index 0)
            predicted_text.append(alphabet[idx])
        prev_idx = idx
    
    return ''.join(predicted_text)

def predict_single_image(model, image_path, alphabet):
    # Load and preprocess the image
    image = load_image(image_path)
    # Move the image to the device
    image = image.to(device)
    # Set the model to evaluation mode
    model.eval()
    with torch.no_grad():
        # Get the predictions
        output = model(image)
        print(f'Model output: {output}')  # Debugging: print the model output
        print(f'Model output shape: {output.shape}')  # Debugging: print the output shape
        # Decode the predictions
        prediction = decode(output, alphabet)
    return prediction

# Example usage
alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 '  # Define your alphabet including a blank character
image_path = 'data/raw/character_set1/Test_1.png'
prediction = predict_single_image(crnn, image_path, alphabet)
print('Prediction:', prediction)

Model output: tensor([[[-6.1337e+00, -7.4148e+00, -7.6299e-01,  ..., -6.6640e+00,
          -1.1636e+01, -1.1873e+01]],

        [[-7.2000e-05, -1.4874e+01, -1.4837e+01,  ..., -1.6584e+01,
          -1.7709e+01, -1.6164e+01]],

        [[-2.3842e-06, -1.8139e+01, -1.7818e+01,  ..., -2.0427e+01,
          -2.1621e+01, -1.9644e+01]],

        ...,

        [[-1.1921e-06, -1.8024e+01, -1.8447e+01,  ..., -2.1068e+01,
          -2.1932e+01, -1.9875e+01]],

        [[-2.5034e-06, -1.7182e+01, -1.7975e+01,  ..., -2.0359e+01,
          -2.1145e+01, -1.9072e+01]],

        [[-1.2755e-05, -1.5133e+01, -1.6552e+01,  ..., -1.8425e+01,
          -1.8884e+01, -1.6827e+01]]], device='cuda:0')
Model output shape: torch.Size([33, 1, 86])
Max indices: tensor([2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
Prediction: c
