In [1]:
import os, json
from PIL import Image
import numpy as np

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

In [2]:

from data.dataset import TextDataset as TDataset
from data.data_utils import collate_fn
from trainer.train import train
from trainer.sequence_decoder import ctc_decode
from modeling.model_utils import load_model

from data.custom_sampler import CustomDatasetSampler
from configs.config_crnn import train_config
from configs.dataconfig import (
    train_source, val_source, mapper, test_sources
)
from utils.augment import Augmentation

In [3]:
# Create the checkpoint directory named in the training config.
# exist_ok=True means an already existing directory is not an error,
# so this cell can be re-run safely.
os.makedirs(train_config['checkpoints_dir'], exist_ok=True)

In [4]:
def load_backbone(model, saved_path):
    """Copy shape-compatible weights from a checkpoint into ``model``.

    Checkpoint tensors are paired with the model's entries by *position*
    (the i-th tensor of the checkpoint goes to the i-th key of
    ``model.state_dict()``), not by name. A tensor is taken from the
    checkpoint only when its size equals the size of the paired model entry;
    otherwise the model keeps its current value for that entry. Entries that
    have no counterpart because the checkpoint holds fewer tensors than the
    model also keep their current values.

    Args:
        model: module exposing ``state_dict()`` and ``load_state_dict()``.
        saved_path: path of a file readable by ``torch.load`` that contains
            a mapping with a ``'state_dict'`` entry.

    Returns:
        The same ``model`` object, updated in place.
    """
    current_model_dict = model.state_dict()

    # map_location='cpu' lets a checkpoint written on a GPU machine be read
    # on a host without CUDA; load_state_dict copies the values into the
    # model's own parameters afterwards.
    checkpoint = torch.load(saved_path, map_location='cpu')
    loaded_tensors = list(checkpoint['state_dict'].values())

    # Start from the model's current values so every key is present for the
    # strict load below, even when the checkpoint is shorter than the model.
    new_state_dict = dict(current_model_dict)
    mis_matched_layers = []
    for key, tensor in zip(current_model_dict, loaded_tensors):
        if tensor.size() == current_model_dict[key].size():
            new_state_dict[key] = tensor
        else:
            mis_matched_layers.append(key)

    if mis_matched_layers:
        print(f"{len(mis_matched_layers)} layers found.")
        print(mis_matched_layers)

    model.load_state_dict(new_state_dict, strict=True)

    print('model loaded successfully')
    return model

In [6]:
def define_crnn_model(cfg, num_class, reload_checkpoint = ''):
    """Instantiate the CRNN network and optionally restore saved weights.

    Args:
        cfg: mapping that provides the 'map_to_seq_hidden', 'rnn_hidden'
            and 'leaky_relu' entries forwarded to the CRNN constructor.
        num_class: forwarded as the second positional argument of CRNN.
        reload_checkpoint: when truthy, the freshly built network and this
            value are handed to ``load_model`` and its result is returned.

    Returns:
        The CRNN instance, or whatever ``load_model`` returns when a
        checkpoint is requested.
    """
    # Kept as a function-scope import, as in the original cell.
    from modeling.crnn import CRNN

    hyper_params = {
        name: cfg[name]
        for name in ('map_to_seq_hidden', 'rnn_hidden', 'leaky_relu')
    }
    # The leading 1 is the first positional CRNN argument
    # (presumably the input channel count -- confirm against modeling.crnn).
    network = CRNN(1, num_class, **hyper_params)

    if not reload_checkpoint:
        return network

    network = load_model(network, reload_checkpoint)
    print('model loaded successfully')
    return network

In [7]:
from BnTokenizer import TrieTokenizer
from BnTokenizer.base import BnGraphemizer

# Grapheme-level tokenizer configured with a trie-based tokenizer class,
# a maximum length of 64 and NFKC unicode normalization.
tokenizer = BnGraphemizer(
    tokenizer_class=TrieTokenizer,
    max_len=64,
    normalize_unicode=True,
    normalization_mode='NFKC',
    normalizer="unicode",
    printer=print
)

# Read the grapheme list through a context manager so the file handle is
# closed deterministically, and with an explicit UTF-8 encoding so the
# result does not depend on the platform's default text encoding.
with open("graphemes.json", 'r', encoding='utf-8') as grapheme_file:
    graphemes = json.load(grapheme_file)
tokenizer.add_tokens(graphemes,reset_oov=True)

Selected Tokenizer: TrieTokenizer
Max Sequence Length: 64
Normalize Text: True
Normalizar: unicode
Normalization Mode: NFKC
update completed.[2143] new vocabs added. Current vocab count: 2145


In [None]:
# Hand the tokenizer a target path inside the checkpoint directory
# (presumably so the vocabulary is stored next to the model checkpoints).
tokenizer.save_vocab(f"{train_config['checkpoints_dir']}/tokenizer_vocab.json")

In [8]:
from data.data_source_controller import DataSourceController

def process_text(x):
    """Return ``x`` with U+200C (zero-width non-joiner) and U+200D
    (zero-width joiner) characters removed."""
    for zero_width in ('\u200c', "\u200d"):
        x = x.replace(zero_width, '')
    return x

# Keys of `train_source` that feed the training set.
train_source_names = (
    'boise_camera_train',
    'boise_scan_train',
    'boise_conjunct_train',
    'syn_boise_conjunct_train',
    # Currently disabled sources:
    # 'bn_grapheme_train',
    # 'syn_train'
    # "bangla_writting_train"
    # "bn_htr_train"
)

# Samples whose label length is 30 or more are filtered out; labels are
# passed through process_text.
train_data = DataSourceController(
    transform=process_text,
    filter=lambda sample: len(sample.label) < 30,
)
for source_name in train_source_names:
    train_data.add_data(**train_source[source_name])

Out of 21026 boise_camera_train,21026 are kept after filtering
Total data 21026
Out of 20367 boise_scan_train,20367 are kept after filtering
Total data 41393
Out of 5798 boise_conjunct_train,5798 are kept after filtering
Total data 47191
Out of 25000 syn_boise_conjunct_train,5000 are kept after filtering
Total data 52191


In [9]:
# Keys of `val_source` that feed the validation set.
val_source_names = (
    'boise_camera_val',
    'boise_scan_val',
    'boise_conjunct_val',
    # Currently disabled sources:
    # 'syn_val',
    # 'syn_boise_conjunct_val',
    # #'bn_grapheme_val',
    # "bangla_writting_val",
    # "bn_htr_val"
)

# Same label-length filter (< 30) and text transform as the training data.
val_data = DataSourceController(
    transform=process_text,
    filter=lambda sample: len(sample.label) < 30,
)
for source_name in val_source_names:
    val_data.add_data(**val_source[source_name])

Out of 2630 boise_camera_val,2630 are kept after filtering
Total data 2630
Out of 2620 boise_scan_val,2620 are kept after filtering
Total data 5250
Out of 824 boise_conjunct_val,824 are kept after filtering
Total data 6074


In [10]:
# Training dataset: images are requested at 32x128 (height x width) and
# an Augmentation(.50) object is passed as the noise augmenter.
train_dataset = TDataset(
    train_data.data,
    tokenizer,
    img_height=32,
    img_width=128,
    noiseAugment=Augmentation(.50),
)

# Disabled custom sampler, kept for reference.
# NOTE(review): the snippet reads `data.data`, but no `data` name is defined
# in the visible cells -- presumably `train_data.data` was meant.
# sampler=CustomDatasetSampler(
#     train_dataset,
#     num_samples = train_config['max_sample_per_epoch'],
#     labels = [mapper[d.id] for d in data.data],
# )

train_loader_options = {
    'batch_size': train_config['train_batch_size'],
    'collate_fn': collate_fn,
    'shuffle': True,  # hard-coded; train_config['shuffle'] is not consulted
    'prefetch_factor': train_config['prefetch_factor'],
    'num_workers': train_config['cpu_workers'],
    # 'sampler': sampler,
}
dataloader = DataLoader(train_dataset, **train_loader_options)

# Number of batches per epoch, shown as the cell output.
len(dataloader)

Total 52191 Images found!!!


408

In [11]:
# Validation dataset: same 32x128 image size as training, with no
# augmentation object supplied.
val_dataset = TDataset(
    val_data.data,
    tokenizer,
    img_height=32,
    img_width=128,
    # noiseAugment=NoiseAugment()
)

# Disabled custom sampler, kept for reference.
# val_sampler=CustomDatasetSampler(
#     val_dataset,
#     num_samples = 12800,
#     labels = [mapper[d.id] for d in val_data.data]
# )

val_loader_options = {
    # The validation loader reuses the training batch size from the config.
    'batch_size': train_config['train_batch_size'],
    'collate_fn': collate_fn,
    'prefetch_factor': train_config['prefetch_factor'],
    'num_workers': train_config['cpu_workers'],
    # 'sampler': val_sampler,
}
val_dataloader = DataLoader(val_dataset, **val_loader_options)

# Number of validation batches, shown as the cell output.
len(val_dataloader)

Total 6074 Images found!!!


48

In [None]:
# Override two entries of the imported training config for this run.
train_config['max_iter'] = 170
train_config['epochs'] = 60

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# One class more than the tokenizer vocabulary
# (presumably the CTC blank symbol -- confirm against the CRNN / decoder).
num_classes = len(tokenizer.vocab) + 1
model = define_crnn_model(train_config, num_classes)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=train_config['lr'])

# Multiply the learning rate by 0.4 after 5 epochs without improvement of
# the monitored (minimised) quantity, never going below 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.4,
    patience=5,
    verbose=True,
    min_lr=1e-6,
)

# Summed CTC loss; infinite losses (and their gradients) are zeroed.
criterion = CTCLoss(reduction='sum', zero_infinity=True).to(device)

In [None]:
# Launch training with the project's `train` routine (imported from
# trainer.train), passing the config, model, optimizer, scheduler, loss,
# device and both data loaders. Its return value replaces `model`.
model = train(
    train_config, model, optimizer,scheduler,
    criterion, device, dataloader, val_dataloader
)