# Experiment with Pytorch
to create an encoder-decoder OCR model with a CNN and an LSTM

In [2]:
# Setup path in .env file
import os
from dotenv import load_dotenv
import cv2
import numpy as np

# Resolve the absolute dataset directory from the project root and the .env file.
load_dotenv()

# PUBTABNET_DATA_DIR is appended verbatim to the project root below, so it is
# expected to start with a path separator (e.g. "/pubtabnet").
data_dir = os.getenv("PUBTABNET_DATA_DIR")
if data_dir is None:
    # Fail with a clear message instead of the opaque "str + NoneType" TypeError.
    raise RuntimeError(
        "PUBTABNET_DATA_DIR is not set; define it in the .env file or the environment."
    )
data_path = data_dir  # alias kept for compatibility; both names always held the same value

# The project root is taken to be the parent of the current working directory.
project_root_dir = os.path.dirname(os.path.abspath("./"))
print("Project root dir:", project_root_dir)

absolute_dir = project_root_dir + data_dir
print("Absolute path:", absolute_dir)

Project root dir: /Users/leonremke/Documents/GIT_REPOS/UNI/neural_networks_seminar
Absolute path: /Users/leonremke/Documents/GIT_REPOS/UNI/neural_networks_seminar/pubtabnet


In [3]:
# Locations inside the dataset directory: the training images, the full
# annotation file, and the file the sampled subset is written to.
image_dir = f"{absolute_dir}/train"
label_file = f"{absolute_dir}/PubTabNet_2.0.0.jsonl"
output_file = f"{absolute_dir}/subset_val.jsonl"

# Number of entries in the subset
subset_size = 10

In [4]:
import sys
sys.path.append('../')
from utils.paddle_dataset import PaddleOCRDataset, ResizeNormalize
from torch.utils.data import DataLoader
# Small annotation subset used for quick experiments.
label_file_small = f"{absolute_dir}/subset_small.json"

# Every image is passed through ResizeNormalize with a 256x256 target size.
transform = ResizeNormalize(size=(256, 256))

train_dataset = PaddleOCRDataset(
    image_dir=image_dir,
    label_file=label_file_small,
    transform=transform,
)
# One sample per batch, reshuffled on every pass over the data.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)


Image names initialised
Label data loaded


In [5]:
def create_alphabet_file(alphabet, file_path):
    """Write an alphabet to ``file_path``, one symbol per line.

    The file starts with a ``START`` line and ends with an ``END`` line.

    Args:
        alphabet: Iterable of string symbols. A ``set``/``frozenset`` is written
            in sorted order so the file (and therefore any index derived from
            it) is identical between runs; any other iterable keeps its order.
        file_path: Destination path. An existing file is overwritten.
    """
    if isinstance(alphabet, (set, frozenset)):
        # String hashing is randomised per interpreter run, so iterating a set
        # directly would produce a different symbol order every time.
        alphabet = sorted(alphabet)

    # Explicit UTF-8: the labels contain non-ASCII characters, which fail to
    # encode on platforms whose default text encoding is not UTF-8.
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write("START\n")
        f.writelines(char + "\n" for char in alphabet)
        f.write("END\n")

# Replace these with your actual ground truth labels
gt_labels = train_loader.dataset.labels

# Flatten every annotation into one string: the tokens of a cell are joined by
# spaces, the cells are joined by spaces, and bold tags become a single space.
alph_labels = []
for entry in gt_labels:
    cell_texts = (" ".join(cell["tokens"]) for cell in entry["html"]["cells"])
    text = " ".join(cell_texts)
    for bold_tag in ("<b>", "</b>"):
        text = text.replace(bold_tag, " ")
    alph_labels.append(text)
print("Ground truth labels loaded: ", len(alph_labels), alph_labels)

# Every distinct character that occurs in any flattened label.
unique_chars = set("".join(alph_labels))
# Radical decomposition is not implemented; the set is intentionally left empty.
unique_radicals = set()

# Output locations of the character and radical alphabets
char_alphabet_file_path = f"{absolute_dir}/character_alphabet.txt"
radical_alphabet_file_path = f"{absolute_dir}/radical_alphabet.txt"

# Persist the character alphabet
create_alphabet_file(unique_chars, char_alphabet_file_path)
print(f"Character alphabet file created at: {char_alphabet_file_path}")

# Create radical alphabet file (if needed)
# create_alphabet_file(unique_radicals, radical_alphabet_file_path)
# print(f"Radical alphabet file created at: {radical_alphabet_file_path}")

Ground truth labels loaded:  10 ['  S p e c i e s     A n a j a ́ s     P o r t e l     S S B V                 W i l d     R u r a l     U r b a n     W i l d     R u r a l     U r b a n     W i l d     R u r a l     U r b a n     T o t a l     ( % )   <i> E v a n d r o m y i a   w a l k e r i </i> 4 4 4 0 1 1 5 6 2 9 6 1 5 1 1 2 5 9 2 6 8 . 8 4 <i> E v a n d r o m y i a   i n f r a s p i n o s a </i> 4 4 0 3 8 2 1 0 0 0 0 1 3 0 1 5 . 1 2 <i> N y s s o m y i a   a n t u n e s i <sup> a </sup> </i> 1 1 3 3 2 0 3 0 1 0 0 4 1 4 . 7 7 <i> M i c r o p y g o m y i a   r o r o t a e n s i s </i> 2 0 1 0 4 0 0 2 0 0 2 7 3 . 1 4 <i> S c i o p e m y i a   s o r d e l l i i </i> 7 1 0 1 3 2 0 2 0 0 2 5 2 . 9 1 <i> B i c h r o m o m y i a   f l a v i s c u t e l l a t a <sup> a </sup> </i> 0 0 0 4 0 0 1 6 0 0 2 0 2 . 3 2 <i> N y s s o m y i a   y u i l l i   y u i l l i </i> 4 0 0 0 0 0 0 0 0 4 0 . 4 6 <i> P s a t h y r o m y i a   a r a g a o i </i> 2 0 0 2 0 0 0 0 0 4 0 . 4 6 <i> P s a t h y r 

In [7]:
# https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html#the-ctc-loss
from torch import nn
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM whose per-step output is projected by a linear layer.

    Args:
        nIn: Feature size of every input step.
        nHidden: Hidden size of each LSTM direction.
        nOut: Feature size of every output step.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Both directions are concatenated, hence twice the hidden size.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Map a ``[T, b, nIn]`` sequence to a ``[T, b, nOut]`` sequence."""
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)
        seq_len, batch, features = recurrent.size()
        # Collapse time and batch so the linear layer sees a 2-D matrix.
        flat = recurrent.view(seq_len * batch, features)
        projected = self.embedding(flat)  # [T * b, nOut]
        return projected.view(seq_len, batch, -1)

class CRNN(nn.Module):
    """Convolutional-recurrent network: a 7-layer CNN followed by two bidirectional LSTMs.

    Args:
        opt: Mapping with the keys ``imgH`` (input height), ``nChannels``
            (input channels), ``nHidden`` (LSTM hidden size) and ``nClasses``
            (size of the per-step output).
        leakyRelu: Use ``LeakyReLU(0.2)`` instead of ``ReLU`` after each convolution.
    """

    def __init__(self, opt, leakyRelu=False):
        super().__init__()

        assert opt['imgH'] % 16 == 0, 'imgH has to be a multiple of 16'

        # Per-layer kernel sizes, paddings, strides and output channel counts.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            """Append conv layer ``i`` (+ optional batch norm) and its activation."""
            nIn = opt['nChannels'] if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # halves H and W
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # halves H and W
        convRelu(2, True)
        convRelu(3)
        # Halves H only; stride 1 with padding 1 along W makes W grow by one.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(6, True)  # 2x2 kernel without padding: H and W shrink by one
        self.cnn = cnn

        # The recurrent stack consumes the channels of the last conv layer.
        # The input size was previously hard-coded as nHidden * 2, which only
        # matches those 512 channels when nHidden == 256.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(nm[-1], opt['nHidden'], opt['nHidden']),
            BidirectionalLSTM(opt['nHidden'], opt['nHidden'], opt['nClasses']))

    def forward(self, input):
        """Map a ``[b, nChannels, imgH, W]`` batch to ``[b, T, nClasses]`` scores.

        Raises:
            AssertionError: If the CNN does not reduce the height to exactly 1.
        """
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        # rnn features
        output = self.rnn(conv)
        output = output.transpose(1, 0)  # Tbh to bth
        return output

In [8]:
import torch

class CNN_Encoder(nn.Module):
    """Two conv / batch-norm / ReLU / max-pool stages that turn an image into a feature map.

    Args:
        params: Either a ``dict`` or an object with attributes, providing
            ``input_dim``, ``hidden_dim``, ``output_dim``, ``input_planes``
            (input channels) and ``planes`` (channels of both conv layers).
        leakyRelu: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, params, leakyRelu=False):
        # Was super(CNN_LSTM_OCR, self), a name that does not exist.
        super().__init__()

        # The notebook builds `params` as a plain dict, so support both
        # key lookup and attribute lookup.
        if isinstance(params, dict):
            read = params.__getitem__
        else:
            def read(key):
                return getattr(params, key)

        self.input_dim = read("input_dim")
        self.hidden_dim = read("hidden_dim")
        self.output_dim = read("output_dim")
        self.input_planes = read("input_planes")
        self.planes = read("planes")

        # Stage 1: 1x1 convolution.
        # NOTE(review): padding=1 on a 1x1 kernel only adds a zero border
        # (H and W grow by 2) - confirm this is intended.
        self.conv_layer_1 = nn.Conv2d(self.input_planes, self.planes, kernel_size=1, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 3x3 convolution. It consumes the output of stage 1, so its
        # input channel count is `planes`, not `input_planes`.
        self.conv_layer_2 = nn.Conv2d(self.planes, self.planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.planes)
        self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Encode a ``[b, input_planes, H, W]`` batch into a ``[b, planes, H', W']`` feature map."""
        out = self.conv_layer_1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool_1(out)
        out = self.conv_layer_2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.maxpool_2(out)
        return out


In [9]:
import torch

class LSTM_Decoder(nn.Module):
    """Bidirectional LSTM followed by a linear projection onto the output classes.

    Args:
        params: Either a ``dict`` or an object with attributes, providing
            ``input_dim`` (feature size per step), ``hidden_dim`` (hidden size
            per direction) and ``output_dim`` (size of the per-step output).
        leakyRelu: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, params, leakyRelu=False):
        super().__init__()

        # The notebook builds `params` as a plain dict, so support both
        # key lookup and attribute lookup.
        if isinstance(params, dict):
            read = params.__getitem__
        else:
            def read(key):
                return getattr(params, key)

        self.input_dim = read("input_dim")
        self.hidden_dim = read("hidden_dim")
        self.output_dim = read("output_dim")

        # Was nn.LSTM(nIn, nHidden, ...), two names that are not defined here.
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, bidirectional=True)
        # Both directions are concatenated, hence twice the hidden size.
        self.embedding = nn.Linear(self.hidden_dim * 2, self.output_dim)

    def forward(self, x):
        """Decode a ``[T, b, input_dim]`` sequence into ``[b, T, output_dim]`` scores."""
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        recurrent, _ = self.lstm(x)
        output = self.embedding(recurrent)
        output = output.transpose(1, 0)  # Tbh to bth
        return output


In [None]:
class CNNLSTM_OCR(nn.Module):
    """OCR model that feeds CNN image features into an LSTM decoder.

    Args:
        params: Configuration forwarded unchanged to ``CNN_Encoder`` and
            ``LSTM_Decoder``. The decoder's input size has to equal
            ``channels * height`` of the encoder's feature map.
    """

    def __init__(self, params):
        super().__init__()
        self.cnn_encoder = CNN_Encoder(params)
        self.lstm_decoder = LSTM_Decoder(params)

    def forward(self, x):
        """Encode the image batch ``x`` and decode the resulting feature sequence."""
        # Apply the CNN encoder
        features = self.cnn_encoder(x)

        # An LSTM needs a [T, b, feature] sequence, not a 4-D feature map.
        # As in CRNN.forward, the width is used as the time axis; channels and
        # height are folded together into the per-step feature vector.
        b, c, h, w = features.size()
        sequence = features.permute(3, 0, 1, 2).reshape(w, b, c * h)

        # Apply the LSTM decoder
        return self.lstm_decoder(sequence)


In [10]:
# From https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html#the-ctc-loss
class CustomCTCLoss(torch.nn.Module):
    """CTC loss wrapper that replaces non-finite loss values with zero.

    Expects ``T x B x H`` log-probabilities (softmax over dimension 2).

    Args:
        dim: Stored on the instance as ``self.dim``; not used by the loss itself.
    """

    def __init__(self, dim=2):
        super().__init__()
        self.dim = dim
        self.ctc_loss = torch.nn.CTCLoss(reduction='mean', zero_infinity=True)

    def forward(self, logits, labels,
            prediction_sizes, target_sizes):
        """Return the sanitized CTC loss for one batch.

        Args:
            logits: Log-probabilities handed to ``torch.nn.CTCLoss``.
            labels: Target label indices.
            prediction_sizes: Length of every prediction sequence.
            target_sizes: Length of every target sequence.
        """
        loss = self.ctc_loss(logits, labels, prediction_sizes, target_sizes)
        loss = self.sanitize(loss)
        return self.debug(loss, logits, labels, prediction_sizes, target_sizes)

    def sanitize(self, loss):
        """Return ``loss`` unchanged when finite, otherwise a zero of the same shape."""
        # The previous check, abs(loss - inf) < EPS, evaluates to NaN < EPS for
        # an infinite loss and therefore never fired.
        if torch.isinf(loss).any() or torch.isnan(loss).any():
            # Keep the grad flag so a later loss.backward() does not fail; the
            # replacement is detached from the model, so the batch contributes
            # no gradients.
            return torch.zeros_like(loss).requires_grad_(loss.requires_grad)
        return loss

    def debug(self, loss, logits, labels,
            prediction_sizes, target_sizes):
        """Print the inputs and raise if ``loss`` is NaN; otherwise return ``loss``.

        Raises:
            Exception: When ``loss`` contains NaN.
        """
        # torch.isnan avoids depending on `math`, which this cell never imports.
        if torch.isnan(loss).any():
            print("Loss:", loss)
            print("logits:", logits)
            print("labels:", labels)
            print("prediction_sizes:", prediction_sizes)
            print("target_sizes:", target_sizes)
            raise Exception("NaN loss obtained. But why?")
        return loss

In [None]:
# Setup environment for training
import os
import sys
import pdb
import six
import random
import lmdb
from PIL import Image
import numpy as np
import math
from collections import OrderedDict
from itertools import chain
import logging


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import random_split
from tqdm import *

from types import SimpleNamespace

# `alphabet` was referenced below without ever being defined; derive it from
# the characters collected while building the alphabet file. Sorting makes the
# order reproducible between runs.
# NOTE(review): CTC reserves one index for the blank symbol - confirm whether
# 'len_alphabet' should account for it.
alphabet = sorted(unique_chars)

# Hyper-parameters and run configuration.
# NOTE(review): the *_dim / planes values of 1 look like placeholders - confirm
# them against the encoder's feature-map size before training.
params = {
    "input_dim": 1,
    "hidden_dim": 1,
    "output_dim": 1,
    "input_planes": 1,
    "planes": 1,
    "schedule": False,  # this key used to be listed twice with the same value

    'image_height':32,
    'number_channels':1,
    'number_hidden_layers':256,
    'len_alphabet':len(alphabet),
    'learning_rate':0.001,
    'epochs':4,
    'batch_size':32,
    'model_dir':'model_history',
    'log_dir':'logs',
    'resume':False,
    'cuda':False,
}


# The encoder and decoder read their settings as attributes, so hand the model
# an attribute view of the dict; `params` itself stays a dict for the lookups below.
model = CNNLSTM_OCR(SimpleNamespace(**params))
criterion = CustomCTCLoss()
# Was args['lr'] / opt['epochs']: neither `args` nor `opt` exists in this notebook.
optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
scheduler = CosineAnnealingLR(optimizer, T_max=params['epochs'])
batch_size = params['batch_size']
count = 1
epochs = params['epochs']
cuda = params['cuda']

def init_meters(self):
    """Attach one ``AverageMeter`` per tracked train/validation metric to ``self``."""
    meter_titles = (
        ("avgTrainLoss", "Train loss"),
        ("avgTrainCharAccuracy", "Train Character Accuracy"),
        ("avgTrainWordAccuracy", "Train Word Accuracy"),
        ("avgValLoss", "Validation loss"),
        ("avgValCharAccuracy", "Validation Character Accuracy"),
        ("avgValWordAccuracy", "Validation Word Accuracy"),
    )
    for attribute, title in meter_titles:
        setattr(self, attribute, AverageMeter(title))

def forward(self, x):
    """Run ``self.model`` on ``x`` and swap the first two axes of its output."""
    return self.model(x).transpose(0, 1)

def loss_fn(self, logits, targets, pred_sizes, target_sizes):
    """Evaluate ``self.criterion`` on one batch and return its result."""
    return self.criterion(logits, targets, pred_sizes, target_sizes)

def step(self):
    """Clip the model's gradients to a norm of 0.05, then apply the optimizer update."""
    self.max_grad_norm = 0.05
    trainable = self.model.parameters()
    clip_grad_norm_(trainable, self.max_grad_norm)
    self.optimizer.step()

def schedule_lr(self):
    """Advance ``self.scheduler`` by one step, but only when ``self.schedule`` is truthy."""
    if not self.schedule:
        return
    self.scheduler.step()

def _run_batch(self, batch, report_accuracy=False, validation=False):
    input_, targets = batch['img'], batch['label']
    targets, lengths = self.converter.encode(targets)
    logits = self.forward(input_)
    logits = logits.contiguous().cpu()
    logits = torch.nn.functional.log_softmax(logits, 2)
    T, B, H = logits.size()
    pred_sizes = torch.LongTensor([T for i in range(B)])
    targets= targets.view(-1).contiguous()
    loss = self.loss_fn(logits, targets, pred_sizes, lengths)
    if report_accuracy:
        probs, preds = logits.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = self.converter.decode(preds.data, pred_sizes.data, raw=False)
        ca = np.mean((list(map(self.evaluator.char_accuracy, list(zip(sim_preds, batch['label']))))))
        wa = np.mean((list(map(self.evaluator.word_accuracy, list(zip(sim_preds, batch['label']))))))
    return loss, ca, wa

def run_epoch(self, validation=False):
    """Run one pass over the training or validation data.

    Args:
        validation: Use the validation loader and ``validation_step`` instead
            of the training loader and ``training_step``.

    Returns:
        The result of ``train_end`` or ``validation_end`` for the collected
        per-batch outputs.
    """
    if validation:
        loader = self.val_dataloader()
        pbar = tqdm(loader, desc='Validating', leave=True)
        self.model.eval()
        run_step = self.validation_step
    else:
        loader = self.train_dataloader()
        pbar = tqdm(loader, desc='Epoch: [%d]/[%d] Training'%(self.count,
            self.epochs), leave=True)
        self.model.train()
        run_step = self.training_step

    outputs = []
    for batch in pbar:
        output = run_step(batch)
        pbar.set_postfix(output)
        outputs.append(output)

    if validation:
        return self.validation_end(outputs)

    # Advance the LR schedule once per training epoch only. It used to be
    # stepped after validation passes too, i.e. twice per epoch.
    self.schedule_lr()
    return self.train_end(outputs)

def training_step(self, batch):
    """Run one optimisation step on ``batch`` and return its metrics.

    Returns:
        ``OrderedDict`` with the keys ``loss`` (absolute value of the loss),
        ``train_ca`` and ``train_wa``.
    """
    loss, char_acc, word_acc = self._run_batch(batch, report_accuracy=True)

    # Reset stale gradients, back-propagate, then clip and update.
    self.optimizer.zero_grad()
    loss.backward()
    self.step()

    metrics = OrderedDict()
    metrics['loss'] = abs(loss.item())
    metrics['train_ca'] = char_acc.item()
    metrics['train_wa'] = word_acc.item()
    return metrics

def validation_step(self, batch):
    """Evaluate ``batch`` without updating the model and return its metrics.

    Returns:
        ``OrderedDict`` with the keys ``val_loss`` (absolute value of the
        loss), ``val_ca`` and ``val_wa``.
    """
    loss, char_acc, word_acc = self._run_batch(
        batch, report_accuracy=True, validation=True)

    metrics = OrderedDict()
    metrics['val_loss'] = abs(loss.item())
    metrics['val_ca'] = char_acc.item()
    metrics['val_wa'] = word_acc.item()
    return metrics

def train_dataloader(self):
    """Build a shuffling ``DataLoader`` over ``self.data_train``."""
    loader_options = dict(
        batch_size=self.batch_size,
        collate_fn=self.collate_fn,
        shuffle=True,
    )
    return torch.utils.data.DataLoader(self.data_train, **loader_options)
    
def val_dataloader(self):
    """Build a ``DataLoader`` over ``self.data_val`` that keeps the dataset order."""
    loader_options = dict(
        batch_size=self.batch_size,
        collate_fn=self.collate_fn,
    )
    return torch.utils.data.DataLoader(self.data_val, **loader_options)

def train_end(self, outputs):
    """Feed the per-batch training metrics into the meters and return their values.

    Args:
        outputs: Iterable of mappings with the keys ``loss``, ``train_ca``
            and ``train_wa``.

    Returns:
        Dict with ``train_loss`` (absolute value), ``train_ca`` and ``train_wa``.
    """
    meters = (
        ('loss', self.avgTrainLoss),
        ('train_ca', self.avgTrainCharAccuracy),
        ('train_wa', self.avgTrainWordAccuracy),
    )
    for record in outputs:
        for key, meter in meters:
            meter.add(record[key])

    return {
        'train_loss': abs(self.avgTrainLoss.compute()),
        'train_ca': self.avgTrainCharAccuracy.compute(),
        'train_wa': self.avgTrainWordAccuracy.compute(),
    }

def validation_end(self, outputs):
    """Feed the per-batch validation metrics into the meters and return their values.

    Args:
        outputs: Iterable of mappings with the keys ``val_loss``, ``val_ca``
            and ``val_wa``.

    Returns:
        Dict with ``val_loss`` (absolute value), ``val_ca`` and ``val_wa``.
    """
    meters = (
        ('val_loss', self.avgValLoss),
        ('val_ca', self.avgValCharAccuracy),
        ('val_wa', self.avgValWordAccuracy),
    )
    for record in outputs:
        for key, meter in meters:
            meter.add(record[key])

    return {
        'val_loss': abs(self.avgValLoss.compute()),
        'val_ca': self.avgValCharAccuracy.compute(),
        'val_wa': self.avgValWordAccuracy.compute(),
    }