#### 1. Import Dependencies

In [None]:
import os
import re
import random
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset

import torchvision.transforms as transforms
import torchvision.transforms.functional as F

from sklearn.model_selection import train_test_split
import xml.etree.ElementTree as ET

#### 2. Define vocabularies for model inputs and output

In [None]:
# Multi-letter character names are collapsed to ONE unique symbol each, so a
# plate label is a fixed-length string with exactly one symbol per character.
# 'SAD' is written as 'C': the symbol 'A' is already used for alef, and sharing
# it made the two letters indistinguishable in the label files.
abbreviations = {'EIN':'I', 'HE':'E', 'SAD':'C', 'SIN':'S', 'TA':'X', 'ZH':'W', 'TH':'G', 'SH':'U'}

# Class id of every character, keyed by its full (un-abbreviated) name.
characters = {'0': 0,'1': 1,'2': 2,'3': 3,'4': 4,'5': 5,'6': 6,'7': 7,'8': 8,'9': 9,
              'B': 10,'D': 11,'EIN': 12,'H': 13,'HE': 14,'J': 15, 'L':16, 'M':17, 'N': 18,
             'P': 19,'Q': 20,'SAD': 21,'SIN': 22,'T': 23,'TA': 24,'V': 25,'Y': 26, 'Z': 27,
             'SH': 28, 'TH':29, 'ZH': 30, 'A': 31, ' ': 32}

# Maps the Persian text found in the XML annotations to the character names
# used above (digits map to themselves).
en_fa = {'0': 0,'1': 1,'2': 2,'3': 3,'4': 4,'5': 5,'6': 6,'7': 7,'8': 8,'9': 9,
         'ز': 'Z','ش': 'SH','ط': 'TA','پ': 'P','ث': 'TH','ژ (معلولین و جانبازان)': 'ZH',
         'الف': 'A','ع': 'EIN','ه‍': 'H','ق': 'Q','ت': 'T','م': 'M','ل': 'L','د': 'D',
         'ی': 'Y','ب': 'B', 'ج': 'J', 'ن': 'N'}

# Class id of every single-symbol label character (the abbreviated form of
# `characters`). Every key must be unique: a repeated key in a dict literal is
# silently overwritten, which shrinks the vocabulary.
indexes = {'0': 0,'1': 1,'2': 2,'3': 3,'4': 4,'5': 5,'6': 6,'7': 7,'8': 8,'9': 9,
          'B': 10,'D': 11,'I': 12,'H': 13,'E': 14,'J': 15, 'L':16, 'M':17, 'N': 18,
         'P': 19,'Q': 20,'C': 21,'S': 22,'T': 23,'X': 24,'V': 25,'Y': 26, 'Z': 27,
          'U': 28,'G': 29,'W': 30,'A': 31,' ': 32}

# Guard against a lost or repeated class id: the ids must be exactly
# 0..len(indexes)-1, otherwise the blank id computed below would collide
# with a real character.
if sorted(indexes.values()) != list(range(len(indexes))):
    raise ValueError("indexes must map symbols to unique, contiguous class ids starting at 0")

rev_indexes = {v: k for k, v in indexes.items()}
n_classes = len(indexes) + 1  # +1 for CTC blank
blank_index = n_classes - 1 # The last class id is the blank

#### 3. Define the path to your data and image transform function

In [None]:
# Directories (relative to the working directory) holding the plate images
# and their one-file-per-image text labels. The trailing slash matters:
# file names are appended to these strings directly.
images_path = "OCR-Data/SelectedData/images/"
labels_path = "OCR-Data/SelectedData/labels/"

class AddGaussianNoise(object):
    """Transform that adds independent Gaussian noise to every tensor element.

    Args:
        mean: Constant offset added to the noise.
        std: Scale applied to the standard-normal samples.
    """

    def __init__(self, mean=0., std=0.01):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return ``tensor`` plus freshly sampled noise of the same shape."""
        # Sample the noise on the input's device; noise created on the default
        # (CPU) device cannot be added to a tensor living on another device.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

# Random augmentations, applied while the sample is still a PIL image.
augmentation_steps = [
    transforms.RandomAffine(degrees=0, translate=(0.05,0.1)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
]

# Resize to the fixed 32x128 network input, convert to a tensor, add a little
# Gaussian noise and finally normalise with mean 0.5 / std 0.5.
tensor_steps = [
    transforms.Resize((32, 128)),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.01),
    transforms.Normalize((0.5,), (0.5,)),
]

# Full pipeline: augmentation first, then tensor conversion and normalisation.
transform = transforms.Compose(augmentation_steps + tensor_steps)

#### 4. Loading data
##### 4.1. Writing label files for the generated data (labels are parsed from the image file names)

In [None]:
gen_path = "OCR-Data/SelectedData/gen/"

# Generated image names encode the plate text, with parts separated by '_',
# '-' or '.'. The pattern is a raw string ("\." in a plain string literal is
# an invalid escape sequence) and is compiled once instead of once per file.
name_separators = re.compile(r"_|-|\.")

for img in os.listdir(gen_path):
    # Label file shares the image's name up to the first '.'.
    full_name = img.split('.')[0]
    file_name = name_separators.split(img)
    if len(file_name) < 3:
        # Not a generated-plate file name (no letter field to read); skip it
        # instead of failing with an IndexError below.
        continue
    # Part 2 is the plate letter; collapse multi-letter names to one symbol.
    if file_name[2] in abbreviations:
        file_name[2] = abbreviations[file_name[2]]
    # Drop the first part (an index prefix) and the last part (what follows
    # the final separator, presumably the extension); the rest is the label.
    name = ''.join(file_name[1:-1])
    with open(labels_path+full_name+'.txt', 'w') as f:
        f.write(name)

##### 4.2. Writing label files for the real data (labels are read from the XML annotations)

In [None]:
# Real photographs and their XML annotation files.
real_path = "OCR-Data/SelectedData/img/"
xmls_path = "OCR-Data/SelectedData/xml/"

real = sorted(os.listdir(real_path))
xmls = sorted(os.listdir(xmls_path))

# Write one label file per XML annotation: every <name> element holds one
# plate character, which is translated through en_fa and then collapsed to
# its single-symbol form when an abbreviation exists.
for xml_name in xmls:
    stem = xml_name.split(".")[0]
    annotation_root = ET.parse(xmls_path + xml_name).getroot()

    symbols = []
    for name_node in annotation_root.findall(".//name"):
        symbol = en_fa[name_node.text]
        symbols.append(str(abbreviations.get(symbol, symbol)))

    with open(labels_path+stem+'.txt', 'w') as label_file:
        label_file.write(''.join(symbols))

##### 4.3. Loading data into tensor

In [None]:
LABEL_LENGTH = 8  # a usable plate label has exactly this many symbols

images_list = sorted(os.listdir(images_path))
labels_list = sorted(os.listdir(labels_path))

# Pair each label with its image by file stem (the text before the first '.',
# the same rule used when the label files were named). Pairing by position in
# two independently sorted listings silently misaligns samples as soon as one
# directory has a file the other lacks.
image_by_stem = {img_name.split('.')[0]: img_name for img_name in images_list}

images = []
labels = []
skipped = 0

for label_name in labels_list:
    with open(labels_path + label_name, 'r') as f:
        # One class id per label character; an unknown character raises KeyError.
        idx_values = [indexes[ch] for ch in f.read()]

    img_name = image_by_stem.get(label_name.split('.')[0])
    if img_name is None or len(idx_values) != LABEL_LENGTH:
        # No matching image, or a label that cannot be stacked into the
        # fixed-width target tensor.
        skipped += 1
        continue

    # The context manager closes the underlying file once the RGB copy exists.
    with Image.open(images_path + img_name) as raw_image:
        image = raw_image.convert('RGB')
    images.append(transform(image))
    labels.append(idx_values)

if skipped:
    print(f"Skipped {skipped} label file(s) with no matching image or a label length other than {LABEL_LENGTH}.")

if not images:
    raise ValueError(f"No usable image/label pairs found in {images_path} and {labels_path}")

images = torch.stack(images)
labels = torch.tensor(labels)

#### 5. Creating CRNN Model ---> 2D-CONV + LSTM + CTC-LOSS

In [None]:
class CRNN(nn.Module):
    """Convolutional-recurrent network that emits per-time-step class scores.

    A stack of 2D convolutions reduces the input to a feature map of height 1.
    The columns of that map are then read as a sequence by two bidirectional
    LSTMs, followed by dropout and a linear projection to ``nclass`` scores.

    Args:
        imgH: Input image height. Kept in the signature but not read here.
        nc: Number of input image channels.
        nclass: Number of output classes per time step.
        nh: Hidden size of each LSTM direction.
        dropout_prob: Probability used by every dropout layer.
    """

    def __init__(self, imgH, nc, nclass, nh, dropout_prob=0.3):
        super().__init__()

        def conv_stage(in_ch, out_ch, batch_norm=False):
            # 3x3 convolution (stride 1, padding 1), optional batch norm,
            # then an in-place ReLU.
            stage = [nn.Conv2d(in_ch, out_ch, 3, 1, 1)]
            if batch_norm:
                stage.append(nn.BatchNorm2d(out_ch))
            stage.append(nn.ReLU(True))
            return stage

        # Layers are listed in execution order; their positions define the
        # parameter names inside ``self.cnn``.
        layers = []
        layers += conv_stage(nc, 64)
        layers += [nn.MaxPool2d(2, 2), nn.Dropout(dropout_prob)]
        layers += conv_stage(64, 128)
        layers += [nn.MaxPool2d(2, 2), nn.Dropout(dropout_prob)]
        layers += conv_stage(128, 256)
        layers += [nn.Dropout(dropout_prob)]
        layers += conv_stage(256, 256)
        # From here on the pooling halves the height only; width is kept.
        layers += [nn.MaxPool2d((2, 1), (2, 1)), nn.Dropout(dropout_prob)]
        layers += conv_stage(256, 512, batch_norm=True)
        layers += [nn.Dropout(dropout_prob)]
        layers += conv_stage(512, 512, batch_norm=True)
        layers += [nn.MaxPool2d((2, 1), (2, 1)), nn.Dropout(dropout_prob)]
        # Final 2x2 convolution without padding.
        layers += [nn.Conv2d(512, 512, 2, 1, 0), nn.ReLU(True)]
        self.cnn = nn.Sequential(*layers)

        self.rnn1 = nn.LSTM(512, nh, bidirectional=True)
        self.rnn2 = nn.LSTM(nh * 2, nh, bidirectional=True)
        self.dropout_rnn = nn.Dropout(dropout_prob)
        self.embedding = nn.Linear(nh * 2, nclass)

    def forward(self, x):
        """Return class scores with the feature-map width as the first axis."""
        features = self.cnn(x)
        _, _, height, _ = features.size()
        assert height == 1, f"Unexpected height: {height}"

        # Drop the height axis, then move width to the front so the LSTMs
        # read the feature columns as a sequence.
        sequence = features.squeeze(2).permute(2, 0, 1)

        hidden, _ = self.rnn1(sequence)
        hidden, _ = self.rnn2(hidden)
        return self.embedding(self.dropout_rnn(hidden))


#### 7. Split data into train and validation

In [None]:
# Hold out 20% of the samples for validation; the fixed random_state makes
# the split reproducible across runs.
split_parts = train_test_split(images, labels, test_size=0.2, random_state=42)
images_train, images_val, labels_train, labels_val = split_parts

#### 8. Build the model, data loaders, loss and optimizer (on GPU/CUDA when available)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 3-channel input, one output class per vocabulary entry plus the CTC blank.
model = CRNN(imgH=32, nc=3, nclass=n_classes, nh=256).to(device)

# Only the training batches are shuffled; validation order stays fixed.
train_loader = DataLoader(
    TensorDataset(images_train, labels_train), batch_size=16, shuffle=True
)
val_loader = DataLoader(
    TensorDataset(images_val, labels_val), batch_size=16, shuffle=False
)
train_dataset = train_loader.dataset
val_dataset = val_loader.dataset

# zero_infinity=True replaces infinite losses (and their gradients) by zero.
ctc_loss = nn.CTCLoss(blank=blank_index, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)

In [None]:
def _ctc_batch_loss(images_batch, labels_batch):
    """Run the model on one batch and return its CTC loss tensor.

    Shared by the training and validation passes so both compute the loss
    in exactly the same way.
    """
    images_batch, labels_batch = images_batch.to(device), labels_batch.to(device)
    output = model(images_batch)

    batch_size = labels_batch.size(0)
    # Every sample uses the full output sequence (first axis of the model
    # output) and the full, fixed-width label row. The length tensors are
    # created directly on the target device instead of being copied there.
    input_lengths = torch.full((batch_size,), output.size(0), dtype=torch.long, device=device)
    target_lengths = torch.full((batch_size,), labels_batch.size(1), dtype=torch.long, device=device)

    # CTCLoss expects log-probabilities over the class axis.
    return ctc_loss(output.log_softmax(2), labels_batch, input_lengths, target_lengths)


num_epochs = 10
for epoch in range(num_epochs):

    # --- training pass ---
    model.train()
    train_loss = 0
    for images_batch, labels_batch in train_loader:
        optimizer.zero_grad()
        loss = _ctc_batch_loss(images_batch, labels_batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when a loader has no batches.
    avg_train_loss = train_loss / max(len(train_loader), 1)

    # --- validation pass (no gradients, dropout/batch-norm in eval mode) ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images_batch, labels_batch in val_loader:
            val_loss += _ctc_batch_loss(images_batch, labels_batch).item()

    avg_val_loss = val_loss / max(len(val_loader), 1)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

#### 9. Decode result of network

In [None]:
def decode_prediction(pred, blank=None, index_to_char=None):
    """Greedy CTC decoding of a batch of network outputs.

    For every sequence the highest-scoring class is taken at each time step,
    consecutive repeats are collapsed, and blanks are dropped.

    Args:
        pred: Score tensor indexed as (time, batch, class).
        blank: Class id of the CTC blank. Defaults to the module-level
            ``blank_index``.
        index_to_char: Mapping from class id to output character. Defaults to
            the module-level ``rev_indexes``. Ids missing from the mapping
            contribute nothing to the output.

    Returns:
        A list with one decoded string per batch element.
    """
    if blank is None:
        blank = blank_index
    if index_to_char is None:
        index_to_char = rev_indexes

    # (time, batch, class) -> (batch, time) best class ids, converted to plain
    # Python ints in one call rather than one .item() per element.
    best_paths = torch.argmax(pred.permute(1, 0, 2), dim=2).tolist()

    results = []
    for seq in best_paths:
        prev = -1
        chars = []
        for p in seq:
            # Emit a character only when it is not blank and starts a new run.
            if p != blank and p != prev:
                chars.append(index_to_char.get(p, ''))
            prev = p
        results.append(''.join(chars))
    return results

#### 10. Evaluation

In [None]:
# Qualitative check on a single sample: decode the network's prediction and
# print it next to the ground-truth label of the SAME sample.
model.eval()
with torch.no_grad():
    sample_idx = 0  # one index for both image and label so the two lines are comparable
    image = images[sample_idx].unsqueeze(0).to(device)  # add the batch axis
    output = model(image)
    decoded = decode_prediction(output)
    print("Predicted:", decoded[0])
    print("Ground truth:", ''.join([rev_indexes[v.item()] for v in labels[sample_idx]]))

#### 11. Saving weights of network

In [None]:
weights_path = 'OCRModel/crnn_weights.pth'
# Create the target directory first: torch.save does not create missing
# directories, and failing here would discard the whole training run.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(model.state_dict(), weights_path)