In [None]:
CRNN+GRAPHEMIZER+BnHTR.ipynb

In [1]:
import os, json
from PIL import Image
import numpy as np

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

In [2]:

from data.dataset import TextDataset as TDataset
from data.data_utils import collate_fn
from trainer.train import train
from trainer.sequence_decoder import ctc_decode
from modeling.model_utils import load_model

from data.custom_sampler import CustomDatasetSampler
from configs.config_crnn import train_config
from configs.dataconfig import (
    train_source, val_source, mapper, test_sources
)
from utils.augment import Augmentation

In [3]:
# Make sure the checkpoint/report output directory exists (no-op when it already does).
os.makedirs(name=train_config['checkpoints_dir'], exist_ok=True)

In [4]:
def load_backbone(model, saved_path, map_location='cpu'):
    """Copy weights from a saved checkpoint into ``model``, pairing tensors by position.

    Checkpoint tensors are matched to ``model.state_dict()`` entries by their
    order, not by key name (presumably so that checkpoints saved under
    different key names can still be reused). An entry keeps the model's
    current value when either:

    * its shape differs from the checkpoint tensor at the same position, or
    * the checkpoint has no tensor at that position (it is shorter).

    Every model key is therefore always present, and ``strict=True`` loading
    cannot fail on missing keys.

    Args:
        model: ``torch.nn.Module`` to load into (modified in place).
        saved_path: path to a ``torch.save`` file holding a dict with a
            ``'state_dict'`` entry.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved on GPU also loads on a CPU-only machine;
            ``load_state_dict`` copies values onto the model's own devices.

    Returns:
        The same ``model`` instance.
    """
    current_model_dict = model.state_dict()
    checkpoint = torch.load(saved_path, map_location=map_location)
    loaded_values = list(checkpoint['state_dict'].values())

    new_state_dict = {}
    mis_matched_layers = []
    for idx, (key, current_value) in enumerate(current_model_dict.items()):
        candidate = loaded_values[idx] if idx < len(loaded_values) else None
        if candidate is not None and candidate.size() == current_value.size():
            new_state_dict[key] = candidate
        else:
            # Fall back to the model's own tensor so the key is still present.
            new_state_dict[key] = current_value
            mis_matched_layers.append(key)

    if mis_matched_layers:
        print(f"{len(mis_matched_layers)} mis-matched layers found (kept current weights).")
        print(mis_matched_layers)

    model.load_state_dict(new_state_dict, strict=True)

    print('model loaded successfully')
    return model

In [6]:
def define_crnn_model(cfg, num_class, reload_checkpoint = ''):
    """Build a CRNN from ``cfg`` and optionally restore weights into it.

    Args:
        cfg: config mapping; the keys 'map_to_seq_hidden', 'rnn_hidden' and
            'leaky_relu' are read (in that order) and forwarded to ``CRNN``.
        num_class: number of output classes passed to ``CRNN``.
        reload_checkpoint: checkpoint path handed to ``load_model``; when
            empty/falsy no weights are loaded.

    Returns:
        The constructed (and possibly restored) CRNN.
    """
    from modeling.crnn import CRNN

    # First positional arg is 1 -- presumably the number of input channels
    # (grayscale images); TODO confirm against modeling.crnn.CRNN.
    model_kwargs = {
        name: cfg[name]
        for name in ('map_to_seq_hidden', 'rnn_hidden', 'leaky_relu')
    }
    crnn = CRNN(1, num_class, **model_kwargs)

    if not reload_checkpoint:
        return crnn

    crnn = load_model(crnn, reload_checkpoint)
    print('model loaded successfully')
    return crnn

In [7]:
from BnTokenizer import TrieTokenizer
from BnTokenizer.base import BnGraphemizer

# Grapheme-level tokenizer: NFKC-normalised input, sequences capped at 64 tokens.
tokenizer = BnGraphemizer(
    tokenizer_class=TrieTokenizer,
    max_len=64,
    normalize_unicode=True,
    normalization_mode='NFKC',
    normalizer="unicode",
    printer=print
)

# The grapheme inventory is presumably non-ASCII (Bengali) text: read it
# explicitly as UTF-8 so the result does not depend on the platform's default
# encoding, and close the file handle instead of leaking it.
with open("graphemes.json", 'r', encoding='utf-8') as graphemes_file:
    graphemes = json.load(graphemes_file)
tokenizer.add_tokens(graphemes, reset_oov=True)

Selected Tokenizer: TrieTokenizer
Max Sequence Length: 64
Normalize Text: True
Normalizar: unicode
Normalization Mode: NFKC
update completed.[2143] new vocabs added. Current vocab count: 2145


In [None]:
# Save the tokenizer vocabulary alongside the checkpoints.
tokenizer.save_vocab("{}/tokenizer_vocab.json".format(train_config['checkpoints_dir']))

In [8]:
from data.data_source_controller import DataSourceController


def process_text(x):
    """Remove zero-width non-joiner (U+200C) and zero-width joiner (U+200D) from a label."""
    return x.replace('\u200c', '').replace('\u200d', '')


# Training sources, added in this order. Currently disabled alternatives:
# 'bn_grapheme_train', 'syn_train', "bangla_writting_train", "bn_htr_train".
TRAIN_SOURCE_KEYS = [
    'boise_camera_train',
    'boise_scan_train',
    'boise_conjunct_train',
    'syn_boise_conjunct_train',
]

# Keep only samples whose label is shorter than 30 characters.
train_data = DataSourceController(
    filter=lambda sample: len(sample.label) < 30,
    transform=process_text,
)
for source_key in TRAIN_SOURCE_KEYS:
    train_data.add_data(**train_source[source_key])

Out of 21026 boise_camera_train,21026 are kept after filtering
Total data 21026
Out of 20367 boise_scan_train,20367 are kept after filtering
Total data 41393
Out of 5798 boise_conjunct_train,5798 are kept after filtering
Total data 47191
Out of 25000 syn_boise_conjunct_train,5000 are kept after filtering
Total data 52191


In [9]:
# Validation sources, added in this order. Currently disabled alternatives:
# 'syn_val', 'syn_boise_conjunct_val', 'bn_grapheme_val',
# "bangla_writting_val", "bn_htr_val".
VAL_SOURCE_KEYS = [
    'boise_camera_val',
    'boise_scan_val',
    'boise_conjunct_val',
]

# Same label-length filter and text cleanup as the training data.
val_data = DataSourceController(
    filter=lambda sample: len(sample.label) < 30,
    transform=process_text,
)
for source_key in VAL_SOURCE_KEYS:
    val_data.add_data(**val_source[source_key])

Out of 2630 boise_camera_val,2630 are kept after filtering
Total data 2630
Out of 2620 boise_scan_val,2620 are kept after filtering
Total data 5250
Out of 824 boise_conjunct_val,824 are kept after filtering
Total data 6074


In [10]:
# 32x128 line crops; Augmentation(.50) is applied to training samples only.
train_dataset = TDataset(
    train_data.data,
    tokenizer,
    img_height=32,
    img_width=128,
    noiseAugment=Augmentation(.50),
)

# The CustomDatasetSampler path (labels via `mapper[d.id]`, capped at
# train_config['max_sample_per_epoch']) is disabled; plain shuffling is used
# instead of train_config['shuffle'].
train_loader_kwargs = dict(
    batch_size=train_config['train_batch_size'],
    collate_fn=collate_fn,
    prefetch_factor=train_config['prefetch_factor'],
    num_workers=train_config['cpu_workers'],
)
dataloader = DataLoader(train_dataset, shuffle=True, **train_loader_kwargs)
len(dataloader)

Total 52191 Images found!!!


408

In [11]:
# Validation crops use the same geometry as training but no augmentation.
val_dataset = TDataset(
    val_data.data,
    tokenizer,
    img_height=32,
    img_width=128,
)

# No sampler and no shuffling: validation batches come in dataset order.
val_loader_kwargs = dict(
    batch_size=train_config['train_batch_size'],
    collate_fn=collate_fn,
    prefetch_factor=train_config['prefetch_factor'],
    num_workers=train_config['cpu_workers'],
)
val_dataloader = DataLoader(val_dataset, **val_loader_kwargs)
len(val_dataloader)

Total 6074 Images found!!!


48

In [None]:
# Run-specific overrides; note these mutate the shared train_config dict.
train_config.update(max_iter=170, epochs=60)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One extra output class beyond the vocabulary -- presumably the CTC blank
# (TODO confirm against the tokenizer's id layout).
num_classes = len(tokenizer.vocab) + 1
model = define_crnn_model(train_config, num_classes)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=train_config['lr'])

# Shrink LR by 0.4x after 5 epochs without improvement, floored at 1e-6.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=.4, patience=5, verbose=True, min_lr=1e-6,
)

# Summed CTC loss; zero_infinity guards against infeasible alignments.
criterion = CTCLoss(reduction='sum', zero_infinity=True).to(device)

In [None]:
# Fit the model; the return value of `train` replaces the in-memory `model`.
model = train(
    train_config,
    model,
    optimizer,
    scheduler,
    criterion,
    device,
    dataloader,
    val_dataloader,
)

In [None]:
from trainer.evaluate import evaluate

# Decoding settings come straight from the training config.
eval_kwargs = dict(
    decode_method=train_config['decode_method'],
    beam_size=train_config['beam_size'],
    max_iter=train_config['max_iter'],
)
evaluation = evaluate(model, val_dataloader, criterion, **eval_kwargs)

In [None]:
# Restore a previously trained checkpoint. map_location=device lets a
# checkpoint saved on GPU load on a CPU-only machine (and vice versa).
checkpoint = torch.load(
    "artifacts/crnn/CRNN+GRAPHEMIZER+BTHR+Boise/crnn_044500_loss_0.8612_acc_0.8996.pt",
    map_location=device,
)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()

In [None]:
def inference(
    cfg, model, inf_loader,tokenizer,
    decode_method='beam_search',
    beam_size=10,
    save_image=True
):
    """Predict over ``inf_loader`` and write a per-sample report to Excel.

    Writes ``<checkpoints_dir>/report.checkpoints.<split>.test.textonly.xlsx``
    (text columns only). When ``save_image`` is true, it also writes
    ``...<split>.test.xlsx``, with each sample's image embedded in column H.
    ``<split>`` is taken from the first sample's tag.

    Args:
        cfg: config mapping; only ``cfg['checkpoints_dir']`` is read.
        model: model passed through to ``trainer.evaluate.predict``.
        inf_loader: loader passed through to ``predict``.
        tokenizer: tokenizer whose ``vocab`` maps token ids to strings.
        decode_method: decoding strategy forwarded to ``predict``.
        beam_size: beam width forwarded to ``predict``.
        save_image: also write the workbook with embedded images.

    Returns:
        DataFrame with columns GroundTruth, Prediction, Edit Distance,
        GT Length, Split and Path (one row per sample). If the loader yields
        no samples, an empty frame is returned and nothing is written.
    """
    import pandas as pd
    from utils.utils import levenshtein_distance
    from trainer.evaluate import predict

    columns = ['GroundTruth', 'Prediction', 'Edit Distance', 'GT Length', 'Split', 'Path']

    def ids_to_text(ids):
        # '<oov>' is rendered as U+2581 (LOWER ONE EIGHTH BLOCK) so unknown
        # tokens stay visible in the spreadsheet.
        return ''.join(tokenizer.vocab[i] for i in ids).replace('<oov>', '\u2581')

    all_gts, all_preds, _wrong_cases, who_are_we = predict(
        model, inf_loader, tokenizer,
        decode_method=decode_method,
        beam_size=beam_size,
    )
    if not who_are_we:
        print("inference: loader produced no samples; no report written")
        return pd.DataFrame(columns=columns)

    report = pd.DataFrame({
        'GroundTruth': [ids_to_text(ids) for ids in all_gts],
        'Prediction': [ids_to_text(ids) for ids in all_preds],
    })
    report['Edit Distance'] = [
        levenshtein_distance(gt, pred, True)
        for gt, pred in zip(report['GroundTruth'], report['Prediction'])
    ]
    report['GT Length'] = [len(gt) for gt in report['GroundTruth']]
    # Sample tags are '|'-separated: split name first, image path last.
    report['Split'] = [tag.split('|')[0] for tag in who_are_we]
    report['Path'] = [tag.split('|')[-1] for tag in who_are_we]

    split_name = who_are_we[0].split('|')[0]
    prefix = f"{cfg['checkpoints_dir']}/report.checkpoints.{split_name}.test"
    report.to_excel(f"{prefix}.textonly.xlsx")

    if save_image:
        # ExcelWriter.save() was removed in pandas 2.0; the context manager
        # writes and closes the workbook instead.
        with pd.ExcelWriter(f"{prefix}.xlsx", engine='xlsxwriter') as writer:
            report.to_excel(writer, sheet_name='Sheet1')
            worksheet = writer.sheets['Sheet1']
            # Row 1 is the header. Columns A-G hold the index and the six
            # report columns, so H is the first free column.
            for row, img_path in enumerate(report['Path'], start=2):
                worksheet.insert_image(f'H{row}', img_path)

    return report

In [None]:
# Register the Boise State held-out splits (camera / scan / conjunct). The
# on-disk annotation filename really is spelled "test_annotaion.json".
_boise_root = '/home/jahid/Music/bn_dataset/boiseState'
test_sources.update({
    f"boise_{subset}_test": {
        'data': f"{_boise_root}/{subset}/split/test_annotaion.json",
        'base_dir': f"{_boise_root}/{subset}/split/test_crop_images",
        'id': f"boise_{subset}_test",
    }
    for subset in ('camera', 'scan', 'conjunct')
})
test_sources

In [None]:
from data.data_source_controller import DataSourceController

# Held-out splits to score; every other entry in `test_sources` is skipped.
TEST_SPLITS = [
    'bn_htr_test',
    'boise_scan_test',
    'boise_camera_test',
    'boise_conjunct_test',
]

# Reports keyed by split name. The loop uses test_* names so the validation
# data/loader built earlier (used by the evaluate cell) is not silently
# replaced by the last test split.
test_reports = {}
for k, v in test_sources.items():
    if k not in TEST_SPLITS:
        continue

    test_data = DataSourceController(
        filter=lambda x: len(x.label) < 30,
        transform=process_text,
    )
    print(k, v.get('n'))
    test_data.add_data(**v)

    test_dataset = TDataset(
        test_data.data,
        tokenizer,
        img_height=32,
        img_width=128,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=train_config['train_batch_size'],
        collate_fn=collate_fn,
        prefetch_factor=1,
        num_workers=4,
    )
    test_reports[k] = inference(train_config, model, test_dataloader, tokenizer, save_image=True)

In [None]:
import glob
import pandas as pd

# For each text-only report, write "<name>.error.xlsx" containing only rows
# with a non-zero edit distance, with each row's image embedded in column K.
for xl in sorted(glob.glob(f"{train_config['checkpoints_dir']}/*only.xlsx")):
    _report = pd.read_excel(xl, engine='openpyxl').fillna('')
    _report = _report[_report['Edit Distance'] > 0]

    # The glob guarantees the '.xlsx' suffix; replace only that suffix, not
    # any '.xlsx' that might appear elsewhere in the path.
    error_path = xl[:-len('.xlsx')] + '.error.xlsx'

    # ExcelWriter.save() was removed in pandas 2.0; the context manager writes
    # and closes the workbook instead.
    with pd.ExcelWriter(error_path, engine='xlsxwriter') as writer:
        _report.to_excel(writer, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']
        for row, img in enumerate(_report['Path'], start=2):
            worksheet.insert_image(f'K{row}', img)

In [None]:
import torchmetrics
import pandas as pd
import glob

# Score every text-only report in the current checkpoints dir.
# ZWNJ (U+200C) is stripped via Series.str.replace rather than by rebinding the
# notebook-level `process_text` used by the data-loading cells.
for xl in sorted(glob.glob(f"{train_config['checkpoints_dir']}/*only.xlsx")):
    name = os.path.basename(xl)
    try:
        # Read the text columns as str so purely numeric labels stay text.
        _report = pd.read_excel(xl, dtype={'GroundTruth': str, 'Prediction': str}).fillna('')
        if len(_report) == 0:
            print(name, "-> empty report, skipped")
            continue
        gts = _report['GroundTruth'].str.replace('\u200c', '', regex=False)
        preds = _report['Prediction'].str.replace('\u200c', '', regex=False)
        exact = gts == preds
        print(name)
        # torchmetrics takes (preds, target); CER is normalised by target length.
        print("    Char Error Rate", torchmetrics.CharErrorRate()(list(preds), list(gts)))
        print("    Word Error Rate", (~exact).sum() / len(_report))
        print("    Word Accuracy", exact.sum() / len(_report))
    except Exception as e:
        print(e)
        print(name)

In [None]:

%%writefile example.txt


In [None]:
import torchmetrics
import pandas as pd
import glob

# Score the CRNN+GRAPHEMIZER+BnHTR run's text-only reports (CER, WER and word
# accuracy, matching the sibling metric cells). ZWNJ (U+200C) is stripped via
# Series.str.replace rather than by rebinding the notebook-level `process_text`.
for xl in sorted(glob.glob("artifacts/crnn/CRNN+GRAPHEMIZER+BnHTR/*only.xlsx")):
    name = os.path.basename(xl)
    try:
        # Read the text columns as str so purely numeric labels stay text.
        _report = pd.read_excel(xl, dtype={'GroundTruth': str, 'Prediction': str}).fillna('')
        if len(_report) == 0:
            print(name, "-> empty report, skipped")
            continue
        gts = _report['GroundTruth'].str.replace('\u200c', '', regex=False)
        preds = _report['Prediction'].str.replace('\u200c', '', regex=False)
        exact = gts == preds
        print(name)
        # torchmetrics takes (preds, target); CER is normalised by target length.
        print("    Char Error Rate", torchmetrics.CharErrorRate()(list(preds), list(gts)))
        print("    Word Error Rate", (~exact).sum() / len(_report))
        print("    Word Accuracy", exact.sum() / len(_report))
    except Exception as e:
        print(e)
        print(name)

In [None]:
# Switch to inference mode before the ad-hoc predictions below.
model.train(False)


def paste_in_the_middle( image: "Image.Image", canvas: "Image.Image"):
    """Shrink ``image`` to fit ``canvas`` (aspect ratio kept) and centre it.

    ``image`` is resized in place with ``thumbnail``, which only ever shrinks,
    so afterwards it is no larger than the canvas in either dimension. It is
    then pasted so the leftover margin is split evenly on both axes.
    ``canvas`` is also modified in place, so pass copies if the originals must
    be preserved.

    Returns:
        ``canvas`` with ``image`` pasted onto it.
    """
    cw, ch = canvas.size
    # LANCZOS is the filter the ANTIALIAS alias referred to; that alias was
    # removed in Pillow 10.
    image.thumbnail((cw, ch), Image.LANCZOS)
    w, h = image.size
    # After thumbnail w <= cw and h <= ch, so both offsets are >= 0. This also
    # covers the exact-fit case (offset (0, 0)).
    canvas.paste(image, ((cw - w) // 2, (ch - h) // 2))
    return canvas

import glob

# Run the model on a batch of unlabeled handwriting crops. Sort the paths so
# the batch is reproducible (glob order is filesystem-dependent).
all_images = sorted(glob.glob("/home/jahid/Downloads/bn_dataset/ocr_hw_data/data/*.jpg"))

BATCH_START, BATCH_SIZE = 0, 64
images = all_images[BATCH_START:BATCH_START + BATCH_SIZE]

op = []
test_images = []
for path in images:
    # Close the file handle once the grayscale copy exists.
    with Image.open(path) as img:
        gray = img.convert('L')
    canvas = Image.new('L', (128, 32), color='white')
    pasted = paste_in_the_middle(gray.copy(), canvas.copy())
    # Add a channel axis and scale pixels from [0, 255] to [-1, 1].
    arr = np.expand_dims(np.array(pasted), axis=0)
    test_images.append((arr / 127.5) - 1.0)

if test_images:
    test_images = torch.FloatTensor(np.array(test_images))
    with torch.no_grad():
        # Use the notebook's `device` rather than hard-coding 'cuda', so this
        # also works on CPU-only machines.
        logits = model(test_images.to(device))
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)
    preds = ctc_decode(log_probs.detach(), method='beam_search', beam_size=10)
    op.extend(tokenizer.ids_to_text(preds))
else:
    print("No images found; nothing to predict.")

In [None]:
# Pair each image path with its decoded prediction.
[*zip(images, op)]

In [None]:
import torchmetrics
import pandas as pd
import glob

# Score the CRNN+GRAPHEMIZER+Boise run's text-only reports.
# ZWNJ (U+200C) is stripped via Series.str.replace rather than by rebinding the
# notebook-level `process_text` used by the data-loading cells.
for xl in sorted(glob.glob("artifacts/crnn/CRNN+GRAPHEMIZER+Boise/*only.xlsx")):
    name = os.path.basename(xl)
    try:
        # Read the text columns as str so purely numeric labels stay text.
        _report = pd.read_excel(xl, dtype={'GroundTruth': str, 'Prediction': str}).fillna('')
        if len(_report) == 0:
            print(name, "-> empty report, skipped")
            continue
        gts = _report['GroundTruth'].str.replace('\u200c', '', regex=False)
        preds = _report['Prediction'].str.replace('\u200c', '', regex=False)
        exact = gts == preds
        print(name)
        # torchmetrics takes (preds, target); CER is normalised by target length.
        print("    Char Error Rate", torchmetrics.CharErrorRate()(list(preds), list(gts)))
        print("    Word Error Rate", (~exact).sum() / len(_report))
        print("    Word Accuracy", exact.sum() / len(_report))
    except Exception as e:
        print(e)
        print(name)

In [None]:
import torchmetrics
import pandas as pd
import glob

# Score the CRNN+GRAPHEMIZER+Boise+char run's text-only reports.
# ZWNJ (U+200C) is stripped via Series.str.replace rather than by rebinding the
# notebook-level `process_text` used by the data-loading cells.
for xl in sorted(glob.glob("artifacts/crnn/CRNN+GRAPHEMIZER+Boise+char/*only.xlsx")):
    name = os.path.basename(xl)
    try:
        # Read the text columns as str so purely numeric labels stay text.
        _report = pd.read_excel(xl, dtype={'GroundTruth': str, 'Prediction': str}).fillna('')
        if len(_report) == 0:
            print(name, "-> empty report, skipped")
            continue
        gts = _report['GroundTruth'].str.replace('\u200c', '', regex=False)
        preds = _report['Prediction'].str.replace('\u200c', '', regex=False)
        exact = gts == preds
        print(name)
        # torchmetrics takes (preds, target); CER is normalised by target length.
        print("    Char Error Rate", torchmetrics.CharErrorRate()(list(preds), list(gts)))
        print("    Word Error Rate", (~exact).sum() / len(_report))
        print("    Word Accuracy", exact.sum() / len(_report))
    except Exception as e:
        print(e)
        print(name)

In [None]:
import torchmetrics
import pandas as pd
import glob

# Score the CRNN+GRAPHEMIZER+BTHR run's text-only reports.
# ZWNJ (U+200C) is stripped via Series.str.replace rather than by rebinding the
# notebook-level `process_text` used by the data-loading cells.
for xl in sorted(glob.glob("artifacts/crnn/CRNN+GRAPHEMIZER+BTHR/*only.xlsx")):
    name = os.path.basename(xl)
    try:
        # Read the text columns as str so purely numeric labels stay text.
        _report = pd.read_excel(xl, dtype={'GroundTruth': str, 'Prediction': str}).fillna('')
        if len(_report) == 0:
            print(name, "-> empty report, skipped")
            continue
        gts = _report['GroundTruth'].str.replace('\u200c', '', regex=False)
        preds = _report['Prediction'].str.replace('\u200c', '', regex=False)
        exact = gts == preds
        print(name)
        # torchmetrics takes (preds, target); CER is normalised by target length.
        print("    Char Error Rate", torchmetrics.CharErrorRate()(list(preds), list(gts)))
        print("    Word Error Rate", (~exact).sum() / len(_report))
        print("    Word Accuracy", exact.sum() / len(_report))
    except Exception as e:
        print(e)
        print(name)

In [None]:
import torchmetrics
import pandas as pd
import glob

# Score the CRNN+GRAPHEMIZER+BTHR+char run's text-only reports.
# ZWNJ (U+200C) is stripped via Series.str.replace rather than by rebinding the
# notebook-level `process_text` used by the data-loading cells.
for xl in sorted(glob.glob("artifacts/crnn/CRNN+GRAPHEMIZER+BTHR+char/*only.xlsx")):
    name = os.path.basename(xl)
    try:
        # Read the text columns as str so purely numeric labels stay text.
        _report = pd.read_excel(xl, dtype={'GroundTruth': str, 'Prediction': str}).fillna('')
        if len(_report) == 0:
            print(name, "-> empty report, skipped")
            continue
        gts = _report['GroundTruth'].str.replace('\u200c', '', regex=False)
        preds = _report['Prediction'].str.replace('\u200c', '', regex=False)
        exact = gts == preds
        print(name)
        # torchmetrics takes (preds, target); CER is normalised by target length.
        print("    Char Error Rate", torchmetrics.CharErrorRate()(list(preds), list(gts)))
        print("    Word Error Rate", (~exact).sum() / len(_report))
        print("    Word Accuracy", exact.sum() / len(_report))
    except Exception as e:
        print(e)
        print(name)