### Import libraries

In [None]:
import os
import torch
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
from torchvision import transforms

### Classes


In [None]:
class CRNN(nn.Module):
    """CNN feature extractor followed by a recurrent sequence model and a per-step classifier.

    The three stages are injected by the caller. ``forward`` turns the CNN
    feature map of shape ``(B, C, H, W)`` into a width-major sequence of shape
    ``(W, B, C * H)``, so each feature-map column becomes one time step. It
    then runs ``lstm`` over that sequence and applies ``fc`` to every step.
    """

    def __init__(self, cnn, lstm, fc):
        super().__init__()
        # Assigning the modules as attributes registers them as submodules, so
        # .parameters(), .to(device) and state_dict() include them.
        self.cnn, self.lstm, self.fc = cnn, lstm, fc

    def forward(self, x):
        feature_map = self.cnn(x)
        batch, channels, height, width = feature_map.size()

        # (B, C, H, W) -> (W, B, C, H) -> (W, B, C*H): width becomes the time axis.
        sequence = feature_map.permute(3, 0, 1, 2).reshape(width, batch, channels * height)

        recurrent_out, _state = self.lstm(sequence)
        return self.fc(recurrent_out)

class HTRDataset(Dataset):
    """Handwritten-text dataset driven by a CSV with ``FILENAME`` and ``IDENTITY`` columns.

    ``FILENAME`` is resolved relative to ``img_root_dir``. ``IDENTITY`` is the
    transcription. Each item is ``(image, label, label_length, text)``:

    * ``image``: float32 tensor of shape ``[1, img_height, img_width]``. It is
      produced by ``resize_and_pad`` and scaled to ``[0, 1]`` (white = 1.0),
      then passed through ``transform`` if one is given.
    * ``label``: ``torch.long`` tensor of indices from ``char_to_idx``.
      Characters missing from the mapping are dropped.
    * ``label_length``: ``len(label)``.
    * ``text``: the transcription string, unfiltered.
    """

    def __init__(self, csv_file, img_root_dir, char_to_idx, img_height=64, img_width=256, transform=None):
        # Read IDENTITY verbatim as text. With pandas defaults, all-digit labels
        # become numbers and empty / "NA" / "NULL" cells become NaN floats.
        # Iterating the characters of either then raises TypeError in __getitem__.
        self.data = pd.read_csv(csv_file, dtype={'IDENTITY': str}, keep_default_na=False)
        self.img_root_dir = img_root_dir
        self.img_height = img_height
        self.img_width = img_width
        self.char_to_idx = char_to_idx
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image_path = os.path.join(self.img_root_dir, str(row['FILENAME']))
        text = str(row['IDENTITY'])

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        image = resize_and_pad(image, self.img_height, self.img_width)

        image = image.astype('float32') / 255.0
        image = torch.tensor(image).unsqueeze(0)  # shape: [1, H, W]

        if self.transform:
            image = self.transform(image)

        label_encoded = [self.char_to_idx[c] for c in text if c in self.char_to_idx]
        label_tensor = torch.tensor(label_encoded, dtype=torch.long)

        return image, label_tensor, len(label_tensor), text


### Functions

In [None]:
def build_charset(csv_file):
    """Return the sorted list of distinct characters in the CSV's ``IDENTITY`` column.

    The column is read verbatim as text. Empty cells contribute nothing, and
    labels such as ``"007"`` or ``"NULL"`` contribute their characters. They
    are not parsed into numbers or NaN.

    Args:
        csv_file: path to a CSV containing an ``IDENTITY`` column.

    Returns:
        Sorted list of single-character strings.
    """
    # Without dtype=str / keep_default_na=False, pandas yields ints for all-digit
    # labels and NaN floats for empty cells, and set.update(<number>) raises TypeError.
    df = pd.read_csv(csv_file, dtype={'IDENTITY': str}, keep_default_na=False)
    charset = set()
    for text in df['IDENTITY']:
        charset.update(text)
    return sorted(charset)

def create_mapping(charset):
    """Build forward and reverse lookup tables between characters and class indices.

    Characters get indices 1..N in iteration order. Index 0 is reserved for the
    CTC blank, stored under the key ``'<BLANK>'``.

    Returns:
        ``(char_to_idx, idx_to_char)``, where ``idx_to_char`` inverts ``char_to_idx``.
    """
    forward = {symbol: position for position, symbol in enumerate(charset, start=1)}
    forward['<BLANK>'] = 0

    backward = dict(zip(forward.values(), forward.keys()))
    return forward, backward

def encode_label(text, char_to_idx):
    """Map each character of ``text`` to its index, silently skipping unknown characters."""
    known_chars = (ch for ch in text if ch in char_to_idx)
    return list(map(char_to_idx.__getitem__, known_chars))

def decode_prediction(pred, idx_to_char, blank=0):
    """Greedy (best-path) CTC decoding of a batch of per-step class scores.

    For each sample, take the arg-max class at every time step, collapse runs
    of repeated indices, and drop the blank index.

    Args:
        pred: tensor of shape ``(T, B, C)`` (time, batch, classes).
        idx_to_char: mapping from class index to character. Indices missing
            from it decode to ``''``.
        blank: index of the CTC blank class (default 0, matching ``create_mapping``).

    Returns:
        List of ``B`` decoded strings.
    """
    from itertools import groupby

    # (T, B, C) -> (T, B) best class per step -> (B, T) one row per sample.
    best_paths = pred.argmax(dim=2).transpose(0, 1).tolist()

    decoded = []
    for path in best_paths:
        # groupby collapses consecutive repeats; the blank separates genuine double letters.
        chars = [idx_to_char.get(idx, '') for idx, _run in groupby(path) if idx != blank]
        decoded.append(''.join(chars))
    return decoded

def collate_fn(batch):
    """Combine dataset items into one batch in the layout ``nn.CTCLoss`` expects.

    Images (all the same shape) are stacked along a new batch dimension. The
    variable-length label tensors are concatenated into one flat 1-D target,
    with their individual lengths kept in a separate ``torch.long`` tensor.
    The raw texts are returned as a tuple.
    """
    image_list, label_list, length_list, text_list = zip(*batch)
    return (
        torch.stack(image_list),
        torch.cat(label_list),
        torch.tensor(length_list, dtype=torch.long),
        text_list,
    )

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: network whose output has shape ``(T, B, C)``.
        dataloader: yields ``(images, labels, label_lengths, texts)`` batches
            (see ``collate_fn``).
        criterion: CTC-style loss called as
            ``criterion(log_probs, labels, input_lengths, label_lengths)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.

    Returns:
        Average of the per-batch loss values, or 0.0 if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for images, labels, label_lengths, _texts in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)

        outputs = model(images)
        T, B, _ = outputs.size()

        # Every sample uses the full output sequence as its CTC input length.
        input_lengths = torch.full(size=(B,), fill_value=T, dtype=torch.long, device=device)

        loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # No per-batch decoding here: it only cost time (and printed every sample),
        # and it relied on a global idx_to_char that is not a parameter of this function.
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

def validate(model, dataloader, criterion, idx_to_char, device):
    """Evaluate ``model`` without gradients and report loss and exact-match accuracy.

    Args:
        model: network whose output has shape ``(T, B, C)``.
        dataloader: yields ``(images, labels, label_lengths, texts)`` batches.
        criterion: CTC-style loss called as
            ``criterion(log_probs, labels, input_lengths, label_lengths)``.
        idx_to_char: index-to-character mapping used for greedy decoding.
        device: device the batch tensors are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)``. Accuracy is the fraction of samples
        whose decoded string equals the ground-truth text exactly, or 0 when
        no samples were seen.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels, batch_lengths, batch_texts in tqdm(dataloader):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            batch_lengths = batch_lengths.to(device)

            logits = model(batch_images)
            seq_len, batch_size, _num_classes = logits.size()
            full_lengths = torch.full(size=(batch_size,), fill_value=seq_len, dtype=torch.long).to(device)

            batch_loss = criterion(logits.log_softmax(2), batch_labels, full_lengths, batch_lengths)
            loss_sum += batch_loss.item()

            pairs = list(zip(decode_prediction(logits.cpu(), idx_to_char), batch_texts))
            n_correct += sum(predicted == expected for predicted, expected in pairs)
            n_seen += len(pairs)

    accuracy = n_correct / n_seen if n_seen > 0 else 0
    return loss_sum / len(dataloader), accuracy

def resize_and_pad(image, target_height, target_width):
    """Fit a 2-D grayscale image into a ``target_height x target_width`` canvas.

    The image is scaled uniformly (aspect ratio preserved) by the largest
    factor that keeps it entirely inside the canvas. The remaining area on the
    bottom and right is padded with white (255). The result always has shape
    ``(target_height, target_width)``.

    Images that are wide relative to the canvas are shrunk rather than cropped.
    Cropping would discard characters that the transcription still contains,
    which gives CTC impossible targets.

    Args:
        image: 2-D uint8 array, e.g. from ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``.
        target_height: output height in pixels.
        target_width: output width in pixels.

    Raises:
        ValueError: if ``image`` is None (e.g. ``cv2.imread`` failed) or empty.
    """
    if image is None or image.size == 0:
        raise ValueError("resize_and_pad received an empty image; check that the file exists and is readable")

    h, w = image.shape
    scale = min(target_height / h, target_width / w)
    # Clamp to [1, target]: guards against 0-sized resizes for extreme aspect
    # ratios and against float round-off overshooting the canvas.
    new_h = min(target_height, max(1, int(h * scale)))
    new_w = min(target_width, max(1, int(w * scale)))
    resized = cv2.resize(image, (new_w, new_h))

    return cv2.copyMakeBorder(
        resized,
        0, target_height - new_h,  # top, bottom
        0, target_width - new_w,   # left, right
        cv2.BORDER_CONSTANT,
        value=255,
    )

def predict_single_image(model, image_path, char_to_idx, idx_to_char, img_height=64, img_width=256, device='cpu'):
    """Transcribe one image file with ``model`` using greedy CTC decoding.

    The image goes through the same preprocessing as ``HTRDataset.__getitem__``
    (grayscale read, ``resize_and_pad``, scale to ``[0, 1]``) without any
    augmentation.

    Args:
        model: trained network. It is switched to eval mode and left there.
        image_path: path to the image file.
        char_to_idx: unused; kept so existing call sites keep working.
        idx_to_char: index-to-character mapping for decoding.
        img_height, img_width: canvas size passed to ``resize_and_pad``.
        device: device the input tensor is moved to.

    Returns:
        The decoded string.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    model.eval()

    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread returns None for a missing/unreadable file instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    image = resize_and_pad(image, img_height, img_width)
    image = image.astype('float32') / 255.0
    # [H, W] -> [1, 1, H, W]: add channel and batch dimensions.
    image_tensor = torch.tensor(image).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)

    predictions = decode_prediction(output.cpu(), idx_to_char)
    return predictions[0]

#### Used to prepare the dataset

In [None]:
# NOTE: this cell is intentionally disabled. The helpers below are wrapped in a
# bare string literal, so none of them are defined when the notebook runs.
# They record the one-off preprocessing used to build the dataset:
#   - extract_handwritten_part: crops a fixed vertical band (offsets derived from
#     ignore_ratio/printed_ratio plus hard-coded pixel margins), presumably the
#     handwritten region of each source image;
#   - process_and_clean: writes each crop to target_folder and then DELETES the
#     original image (os.remove), so keep a backup before re-enabling;
#   - filter_labels: copies only the lines of a tab-separated label file whose
#     first field is in image_ids;
#   - get_image_ids: collects the stems of the .png files in a folder.
'''
def extract_handwritten_part(image_path, ignore_ratio=0.09, printed_ratio=0.1):
    image = cv2.imread(image_path)
    height, width, _ = image.shape

    ignore_height = int(height * ignore_ratio)
    printed_height = int(height * printed_ratio)

    handwritten_text = image[ignore_height + printed_height + 50 : printed_height * 7 + 320, :]
    return handwritten_text

def save_image(image, path):
    cv2.imwrite(path, image)

def process_and_clean(images, source_folder, target_folder):
    for img_name in images:
        img_path = os.path.join(source_folder, img_name)
        new_path = os.path.join(target_folder, img_name)

        # Extract and save
        handwritten_img = extract_handwritten_part(img_path)
        save_image(handwritten_img, new_path)

        # Delete the original image
        os.remove(img_path)

def filter_labels(label_path, image_ids, save_path):
    with open(label_path, 'r') as infile, open(save_path, 'w') as outfile:
        for line in infile:
            image_id = line.split('\t')[0]
            if image_id in image_ids:
                outfile.write(line)

def get_image_ids(folder_path):
    return {os.path.splitext(f)[0] for f in os.listdir(folder_path) if f.endswith('.png')}
'''

# Prepare

### Directories

In [None]:
# Root of the unpacked dataset; every other path is built from it
# (data_path ends with '/', so plain string prefixing is enough).
data_path = '/content/dataset/'

main_labels_path = f'{data_path}labels.csv'  # full label file, used to build the charset
train_path = f'{data_path}train/'
val_path = f'{data_path}val/'
test_path = f'{data_path}test/'

# Per-split label CSVs (FILENAME, IDENTITY) consumed by HTRDataset.
filtered_labels = dict(
    train=f'{data_path}train_ds.csv',
    val=f'{data_path}validate_ds.csv',
    test=f'{data_path}test_ds.csv',
)

#### Charset

In [None]:
# One vocabulary built from the full label file, shared by the train and val datasets.
charset = build_charset(csv_file=main_labels_path)
char_to_idx, idx_to_char = create_mapping(charset=charset)

#### CNN model

In [None]:
# Convolutional feature extractor. All four poolings halve the height
# (height / 16 overall). Only pool1 and pool2 halve the width; pool3 and pool4
# use a (2, 1) window, which keeps more horizontal resolution for the sequence model.
cnn = nn.Sequential()
for _layer_name, _layer in (
    ('conv1', nn.Conv2d(1, 64, kernel_size=3, padding=1)),
    ('relu1', nn.ReLU()),
    ('pool1', nn.MaxPool2d(2, 2)),

    ('conv2', nn.Conv2d(64, 128, kernel_size=3, padding=1)),
    ('relu2', nn.ReLU()),
    ('pool2', nn.MaxPool2d(2, 2)),

    ('conv3', nn.Conv2d(128, 256, kernel_size=3, padding=1)),
    ('relu3', nn.ReLU()),
    ('conv4', nn.Conv2d(256, 256, kernel_size=3, padding=1)),
    ('relu4', nn.ReLU()),
    ('pool3', nn.MaxPool2d((2, 1), (2, 1))),

    ('conv5', nn.Conv2d(256, 512, kernel_size=3, padding=1)),
    ('bn1', nn.BatchNorm2d(512)),
    ('relu5', nn.ReLU()),
    ('conv6', nn.Conv2d(512, 512, kernel_size=3, padding=1)),
    ('bn2', nn.BatchNorm2d(512)),
    ('relu6', nn.ReLU()),
    ('pool4', nn.MaxPool2d((2, 1), (2, 1))),
):
    cnn.add_module(_layer_name, _layer)
del _layer_name, _layer

num_classes = len(char_to_idx)  # includes the '<BLANK>' entry at index 0
img_height = 64

lstm = nn.LSTM(
    input_size=512 * (img_height // 16),  # conv6 channels x remaining feature-map height
    hidden_size=256,
    num_layers=2,
    bidirectional=True,
    batch_first=False,
)

# Bidirectional LSTM concatenates both directions: 2 * 256 = 512 features per step.
fc = nn.Linear(2 * lstm.hidden_size, num_classes)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = CRNN(cnn=cnn, lstm=lstm, fc=fc).to(device)

### Dataset

#### Data Augmentation

In [None]:
# Light geometric/photometric augmentation for the [1, H, W] float tensors that
# HTRDataset produces. Pixels are scaled to [0, 1] and the padding/background is white (1.0).
train_transforms = transforms.Compose([
    # fill=1.0 paints the areas exposed by rotation/translation/shear white to
    # match the background. The default fill of 0 would add black wedges.
    transforms.RandomAffine(degrees=4, translate=(0.02, 0.02), scale=(0.95, 1.05), shear=4, fill=1.0),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)),
    transforms.RandomErasing(p=0.1, scale=(0.01, 0.03), ratio=(0.3, 3.3), value='random'),
])

#### Train and Validation

In [None]:
# Augmentation is applied to the training split only; validation images are used as-is.
train_dataset = HTRDataset(
    csv_file=filtered_labels['train'],
    img_root_dir=train_path,
    char_to_idx=char_to_idx,
    transform=train_transforms,
)
val_dataset = HTRDataset(
    csv_file=filtered_labels['val'],
    img_root_dir=val_path,
    char_to_idx=char_to_idx,
)

# collate_fn concatenates the variable-length labels into one flat CTC target.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

#### Loss function and optimizer

In [None]:
# Blank index 0 matches the '<BLANK>' entry added by create_mapping.
# zero_infinity=True zeroes infinite losses (e.g. an input sequence too short
# to align with its target) instead of letting inf/NaN reach the gradients.
ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

## Train the model

In [None]:
# One training pass followed by a validation pass per epoch (epochs numbered from 1).
epochs = 100
for epoch in range(1, epochs + 1):
    print(f"\nEpoch {epoch}/{epochs}")
    train_loss = train_one_epoch(
        model=model, dataloader=train_loader, criterion=ctc_loss, optimizer=optimizer, device=device,
    )
    val_loss, val_acc = validate(
        model=model, dataloader=val_loader, criterion=ctc_loss, idx_to_char=idx_to_char, device=device,
    )

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.4f}")

### Single Prediction

In [None]:
# Run the trained model on one image from the test folder.
image_path = f'{test_path}TEST_0007.jpg'
predicted_text = predict_single_image(
    model=model,
    image_path=image_path,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    device=device,
)

print("Predicted Text:", predicted_text)