In [None]:
# Install the third-party packages this notebook needs, one quiet (-q) pip call each.
# scikit-image is upgraded to the latest release (-U) and lmdb is left unpinned;
# the remaining packages are pinned to exact versions.
!pip -q install -U scikit-image
!pip -q install einops==0.2.0
!pip -q install gdown==4.4.0
!pip -q install prefetch_generator==1.0.1
!pip -q install imgaug==0.4.0
!pip -q install lmdb

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.7/14.7 MB[0m [31m36.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for gdown (pyproject.toml) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for prefetch_generator (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m299.2/299.2 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# In this notebook the mount is only referenced by the commented-out `!cp` fallbacks
# that copy archives from Drive instead of downloading them with gdown.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Load model: download by Google Drive file ID, then unpack the code and the weights.
# NOTE(review): the download is presumably vietocr.zip, which in turn seems to
# provide weights.zip -- confirm against the Drive file.
!gdown 1VU-qucVHcNbu2qnGkWWNwO0TRhY-7iH2
!unzip -q vietocr.zip
!unzip -q weights.zip

# Alternative (disabled): copy the same archive from a mounted Google Drive instead of using gdown.
#!cp '/content/drive/MyDrive/2023-Projects/Kalapa 2023/Vietnamese Handwritten Text Recognition/model/vietocr.zip' ./
#!unzip -q vietocr.zip
#!unzip -q weights.zip

In [None]:
from vietocr import *

In [None]:
import os
import PIL
import copy
import yaml
import time
import torch
import numpy as np
import torchvision
from torch import nn
from PIL import Image
from einops import rearrange
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, OneCycleLR

In [None]:
# Helpers: config loading, model saving/loading, and the quantization wrapper
def load_config(yml_path):
    """Read a YAML configuration file and return the parsed object.

    Args:
        yml_path: path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (a dict for a
        mapping document).

    Raises:
        yaml.YAMLError: if the document cannot be parsed. The error is printed
            and then re-raised, so a broken config fails here instead of
            returning None and failing later with an unrelated TypeError.
    """
    # Read explicitly as UTF-8: the config carries non-ASCII (Vietnamese)
    # characters, and the platform default encoding is not UTF-8 everywhere.
    with open(yml_path, "r", encoding="utf-8") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise

def save_models(model, file_name, output_path='./weights/'):
    """Save ``model.state_dict()`` to ``<output_path>/<file_name>``.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        file_name: name of the file to write inside ``output_path``.
        output_path: target directory; created together with any missing
            parents. Defaults to ``'./weights/'``, the previously hard-coded
            location.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(output_path, exist_ok=True)
    saved_path = os.path.join(output_path, file_name)
    print('Save files in: ', saved_path)
    # torch.save overwrites an existing file, so no explicit delete is needed.
    torch.save(model.state_dict(), saved_path)
def save_torchscript_model(model, file_name, output_path='./weights/'):
    """Script ``model`` with ``torch.jit.script`` and save it to ``<output_path>/<file_name>``.

    Args:
        model: module to compile with ``torch.jit.script``.
        file_name: name of the file to write inside ``output_path``.
        output_path: target directory; created together with any missing
            parents. Defaults to ``'./weights/'``, the previously hard-coded
            location.

    Returns:
        The path of the saved TorchScript file.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(output_path, exist_ok=True)
    model_filepath = os.path.join(output_path, file_name)
    torch.jit.save(torch.jit.script(model), model_filepath)
    print('Save in: ', model_filepath)
    return model_filepath
def load_torchscript_model(model_filepath, device):
    """Load a TorchScript archive from ``model_filepath``, mapping its storages to ``device``."""
    return torch.jit.load(model_filepath, map_location=device)

class QuantizedCNN(nn.Module):
    """Wrap a float CNN between a ``QuantStub`` and a ``DeQuantStub``.

    The stubs mark by hand where tensors switch from floating point to
    quantized (input side) and back to floating point (output side) in the
    quantized model.
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Float -> quantized conversion point; used for the input only.
        self.quant = torch.quantization.QuantStub()
        # Quantized -> float conversion point; used for the output only.
        self.dequant = torch.quantization.DeQuantStub()
        # The wrapped FP32 model.
        self.model_fp32 = model_fp32

    def forward(self, x):
        """Run ``x`` through quant stub -> wrapped model -> dequant stub."""
        return self.dequant(self.model_fp32(self.quant(x)))

In [None]:
#Update Trainer
class Trainer():
    """Build data loaders, optimizer and scheduler from a config dict and run the train/validate loop.

    Helper names such as ``build_model``, ``Logger``, ``OCRDataset``,
    ``translate`` or ``compute_accuracy`` are not defined in this notebook;
    they are expected to be in scope via ``from vietocr import *``.
    """

    def __init__(self, config, qmodel=None, pretrained=True, augmentor=None):
        """Set up model, optimizer, scheduler, loss and data loaders.

        Args:
            config: nested dict. Keys read here: 'device', 'trainer',
                'predictor', 'dataset', 'aug', 'optimizer', 'dataloader', and
                (only when ``pretrained``) 'pretrain' and 'quiet'.
            qmodel: optional ready-made model. When given it is used as
                ``self.model`` and ``build_model`` is called only for the vocab.
            pretrained: when true, fetch weights via ``download_weights`` and
                load them with ``load_weights``.
            augmentor: transform applied to training data when
                ``config['aug']['image_aug']`` is truthy. ``None`` (default)
                means a fresh ``ImgAugTransform()`` is created here; the default
                is no longer built once at definition time and shared by every
                instance.
        """
        self.config = config
        if qmodel is not None:
            # Only the vocabulary is taken from the freshly built model.
            _, self.vocab = build_model(config)
            self.model = qmodel
        else:
            self.model, self.vocab = build_model(config)
        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']
        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']
        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']
        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        log_path = config['trainer']['log']
        # Always define the attribute so train() can log without an AttributeError
        # when no log file is configured.
        self.logger = Logger(log_path) if log_path else None
        if pretrained:
            weight_file = download_weights(config['pretrain'], quiet=config['quiet'])
            self.load_weights(weight_file)
        self.iter = 0
        self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer, total_steps=self.num_iters, **config['optimizer'])
        self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1)
        transforms = None
        if self.image_aug:
            transforms = augmentor if augmentor is not None else ImgAugTransform()
        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                self.data_root, self.train_annotation, self.masked_language_model, transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen('valid_{}'.format(self.dataset_name),
                    self.data_root, self.valid_annotation, masked_language_model=False)
        self.train_losses = []

    def _log(self, message):
        """Forward ``message`` to the file logger if one was configured."""
        logger = getattr(self, 'logger', None)
        if logger is not None:
            logger.log(message)

    @staticmethod
    def _makedirs_for(filename):
        """Create the parent directory of ``filename`` unless it has none."""
        path, _ = os.path.split(filename)
        # os.makedirs('') raises FileNotFoundError, so skip bare file names.
        if path:
            os.makedirs(path, exist_ok=True)

    def train(self):
        """Run ``self.num_iters`` optimisation steps.

        Every ``print_every`` iterations the mean training loss, the learning
        rate and the accumulated load/compute times are printed and logged.
        Every ``valid_every`` iterations (when a validation annotation is set)
        the model is validated, and the weights are exported whenever the
        full-sequence accuracy improves on the best value seen in this call.
        """
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: start another pass over the training data.
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start
            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start
            total_loss += loss
            self.train_losses.append((self.iter, loss))
            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(self.iter,
                        total_loss/self.print_every, self.optimizer.param_groups[0]['lr'],
                        total_loader_time, total_gpu_time)
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self._log(info)
            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                # config['trainer']['metrics'] is passed as the sample limit of precision().
                acc_full_seq, acc_per_char = self.precision(self.metrics)
                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                self._log(info)
                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

    def validate(self):
        """Return the mean loss over ``self.valid_gen``.

        The model is put in eval mode for the pass and switched back to train
        mode before returning.
        """
        self.model.eval()
        total_loss = []
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
                outputs = self.model(img, tgt_input, tgt_padding_mask)
                outputs = outputs.flatten(0,1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)
                total_loss.append(loss.item())
                del outputs
                del loss
        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss

    def predict(self, sample=None):
        """Decode the validation set.

        Args:
            sample: optional limit; iteration stops after the first batch that
                brings the number of predictions above ``sample``.

        Returns:
            ``(pred_sents, actual_sents, img_files, probs)``. ``probs`` is
            ``None`` when beam search is used (no confidence is produced);
            otherwise it is a list holding the confidences of every processed
            batch, index-aligned with the other lists. Previously only the last
            batch's value was returned, which broke per-sample indexing.
        """
        pred_sents = []
        actual_sents = []
        img_files = []
        probs = []
        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)
            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(batch['img'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['img'], self.model)
            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            img_files.extend(batch['filenames'])
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            if prob is not None:
                # assumes translate() yields one confidence per image along the
                # first axis -- TODO confirm against the vietocr implementation
                if torch.is_tensor(prob):
                    prob = prob.detach().cpu().numpy()
                probs.extend(np.atleast_1d(prob))
            if sample is not None and len(pred_sents) > sample:
                break
        return pred_sents, actual_sents, img_files, (None if self.beamsearch else probs)

    def precision(self, sample=None):
        """Return ``(acc_full_seq, acc_per_char)`` computed on ``predict(sample)``."""
        pred_sents, actual_sents, _, _ = self.predict(sample=sample)
        acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents, pred_sents, mode='per_char')
        return acc_full_seq, acc_per_char

    def visualize_prediction(self, sample=16, errorcase=False, fontname='serif', fontsize=16):
        """Plot up to ``sample`` validation images titled with prediction and ground truth.

        Args:
            sample: limit forwarded to ``predict`` and number of images shown.
            errorcase: when true, show only samples whose prediction differs
                from the ground truth.
            fontname: font family of the titles.
            fontsize: font size of the titles.
        """
        pred_sents, actual_sents, img_files, probs = self.predict(sample)
        if probs is None:
            # Beam search gives no confidence; keep the lists index-aligned.
            probs = [None] * len(img_files)
        if errorcase:
            wrongs = [i for i in range(len(img_files)) if pred_sents[i] != actual_sents[i]]
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
        img_files = img_files[:sample]
        fontdict = {'family':fontname, 'size':fontsize}
        for vis_idx, img_path in enumerate(img_files):
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx]
            prob_text = 'n/a' if prob is None else '{:.3f}'.format(prob)
            # Image.open is lazy: load the pixels before the file is closed.
            with open(img_path, 'rb') as img_file:
                img = Image.open(img_file)
                img.load()
            plt.figure()
            plt.imshow(img)
            plt.title('prob: {} - pred: {} - actual: {}'.format(prob_text, pred_sent, actual_sent), loc='left', fontdict=fontdict)
            plt.axis('off')
        plt.show()

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Plot the first ``sample`` training images titled with their decoded target."""
        n = 0
        for batch in self.train_gen:
            # Use the actual batch length: the last batch can be smaller than batch_size.
            for i in range(len(batch['img'])):
                img = batch['img'][i].numpy().transpose(1,2,0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
                plt.figure()
                plt.title('sent: {}'.format(sent), loc='center', fontname=fontname)
                plt.imshow(img)
                plt.axis('off')
                n += 1
                if n >= sample:
                    plt.show()
                    return
        # Fewer than `sample` images were available: still render what was drawn.
        plt.show()

    def load_checkpoint(self, filename):
        """Restore optimizer state, model state, iteration counter and loss history.

        Reads the keys 'optimizer', 'state_dict', 'iter' and 'train_losses'
        written by ``save_checkpoint``. Tensors are mapped to ``self.device``.
        """
        # The previous version also built a ScheduledOptim here that was never used.
        checkpoint = torch.load(filename, map_location=torch.device(self.device))
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Save iteration counter, model state, optimizer state and loss history to ``filename``."""
        state = {'iter':self.iter, 'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(), 'train_losses': self.train_losses}
        self._makedirs_for(filename)
        torch.save(state, filename)

    def load_weights(self, filename):
        """Load a state dict non-strictly, reporting missing and shape-mismatched parameters.

        Entries whose shape differs from the model parameter are dropped from
        the state dict before loading.
        """
        state_dict = torch.load(filename, map_location=torch.device(self.device))
        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.format(name, param.shape, state_dict[name].shape))
                del state_dict[name]
        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Export a converted (quantized) copy of the model's state dict.

        Works on a deep copy moved to CPU, so the model being trained is left
        untouched; the copy's ``cnn`` is passed through
        ``torch.quantization.convert`` before saving.
        """
        self._makedirs_for(filename)
        qmodel = copy.deepcopy(self.model)
        qmodel.to(torch.device('cpu'))
        qmodel.cnn = torch.quantization.convert(qmodel.cnn.eval(), inplace=True)
        qmodel.eval()
        torch.save(qmodel.state_dict(), filename)

    def batch_to_device(self, batch):
        """Return a copy of ``batch`` with its tensors moved to ``self.device``; 'filenames' is passed through."""
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device, non_blocking=True)
        batch = {'img': img, 'tgt_input':tgt_input, 'tgt_output':tgt_output, 'tgt_padding_mask':tgt_padding_mask, 'filenames': batch['filenames']}
        return batch

    def data_gen(self, lmdb_path, data_root, annotation, masked_language_model=True, transform=None):
        """Build a ``DataLoader`` over an ``OCRDataset`` using ``ClusterRandomSampler`` and ``Collator``.

        Image size limits come from ``config['dataset']``; extra loader keyword
        arguments come from ``config['dataloader']``.
        """
        dataset = OCRDataset(lmdb_path=lmdb_path,
                root_dir=data_root, annotation_path=annotation,
                vocab=self.vocab, transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'])
        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        collate_fn = Collator(masked_language_model)
        gen = DataLoader(dataset, batch_size=self.batch_size, sampler=sampler,
                         collate_fn = collate_fn, shuffle=False, drop_last=False, **self.config['dataloader'])
        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        """Build a ``DataGen`` on 'cpu'; ``lmdb_path`` is accepted but not used."""
        data_gen = DataGen(data_root, annotation, self.vocab, 'cpu',
                image_height = self.config['dataset']['image_height'],
                image_min_width = self.config['dataset']['image_min_width'],
                image_max_width = self.config['dataset']['image_max_width'])
        return data_gen

    def step(self, batch):
        """Run one optimisation step on ``batch`` and return the loss as a Python float.

        Gradients are clipped to a max norm of 1, and both the optimizer and
        the scheduler are stepped.
        """
        self.model.train()
        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
        outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask)
        outputs = outputs.view(-1, outputs.size(2))#flatten(0, 1)
        tgt_output = tgt_output.view(-1)#flatten()
        loss = self.criterion(outputs, tgt_output)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()
        loss_item = loss.item()
        return loss_item

In [None]:
# Load dataset: download the archive by Google Drive file ID and unpack it quietly.
!gdown 1LaWUP_jCaoJJeEOOYHJEUvtTaGByT3c7
!unzip -q quantization_data.zip
# Alternative (disabled): copy the same archive from a mounted Google Drive instead of using gdown.
#!cp '/content/drive/MyDrive/2023-Projects/Kalapa 2023/Vietnamese Handwritten Text Recognition/dataset/quantization_data.zip' ./
#!unzip -q quantization_data.zip

In [None]:
# Load the base config, then override the dataset and trainer sections for this run.
config = load_config('/content/base.yml')

# Dataset name, root folder and train/validation annotation files.
dataset_params = dict(
    name='full_dataset',
    data_root='./quantization_data',
    train_annotation='train_line_annotation.txt',
    valid_annotation='test_line_annotation.txt',
)

# Training schedule, output locations and logging.
params = dict(
    batch_size=1,
    print_every=200,
    valid_every=400,
    iters=5400,
    checkpoint='./quantize_checkpoint/quantize_transformerocr.pth',
    export='./quantize_weights/quantize_transformerocr.pth',
    log='./train.log',
    metrics=None,
)

config['trainer'].update(params)
config['dataset'].update(dataset_params)

# Float weights to start from; the CNN backbone is built without its own pretrained weights.
config['weights'] = "./weights/transformerocr.pth"
config['cnn']['pretrained'] = False

device = config['device']
config

{'vocab': 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ',
 'device': 'cpu',
 'seq_modeling': 'transformer',
 'transformer': {'d_model': 256,
  'nhead': 8,
  'num_encoder_layers': 6,
  'num_decoder_layers': 6,
  'dim_feedforward': 2048,
  'max_seq_length': 1024,
  'pos_dropout': 0.1,
  'trans_dropout': 0.1},
 'optimizer': {'max_lr': 0.0003, 'pct_start': 0.1},
 'trainer': {'batch_size': 1,
  'print_every': 200,
  'valid_every': 400,
  'iters': 5400,
  'export': './quantize_weights/quantize_transformerocr.pth',
  'checkpoint': './quantize_checkpoint/quantize_transformerocr.pth',
  'log': './train.log',
  'metrics': None},
 'dataset': {'name': 'full_dataset',
  'data_root': './quantization_data',
  'train_annotation': 'train_line_annotation.txt',
  'valid_annotation': 'test_line_annotation.txt',
  'image_h

In [None]:
# Build the float model, load its trained weights and fuse layers while in eval mode.
model, vocab = build_model(config)
weights = config['weights']
model.load_state_dict(torch.load(weights, map_location=torch.device(device)))
model = model.eval()
# In every plain nn.Sequential of the backbone, fuse each Conv2d with the two
# layers that follow it.
# NOTE(review): assumes every Conv2d is followed by a fusable pair
# (presumably BatchNorm2d + ReLU) -- confirm against the backbone definition.
for m in model.cnn.model.modules():
    if type(m) == nn.Sequential:
        for n, layer in enumerate(m):
            if type(layer) == nn.Conv2d:
                torch.quantization.fuse_modules(m, [str(n), str(n + 1), str(n + 2)], inplace=True)

# Prepare the model for quantization aware training.
# Wrap the CNN so its input is quantized and its output de-quantized.
quantized_cnn = QuantizedCNN(model_fp32=model.cnn)
# Use the QAT qconfig: get_default_qconfig() returns the post-training
# quantization config, whereas quantization aware training needs the config
# from get_default_qat_qconfig().
quantized_cnn.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
# Print quantization configurations
print(quantized_cnn.qconfig)
# prepare() is the post-training-quantization entry point (it prepares a model
# for the calibration step); for quantization aware training use prepare_qat()
# on the model in train mode.
quantized_cnn = torch.quantization.prepare_qat(quantized_cnn.train(), inplace=True)
model.cnn = quantized_cnn

model.train()
model = model.to(device)

In [None]:
# Re-training: hand the already prepared `model` to the Trainer (qmodel=...) so it
# is used as-is, and skip downloading/loading pretrained weights (pretrained=False).
trainer = Trainer(config=config, qmodel=model, pretrained=False)
trainer.train()

Create train_full_dataset: 100%|█████████████████████████████████| 922/922 [00:05<00:00, 157.41it/s]


Created dataset with 921 samples


train_full_dataset build cluster: 100%|████████████████████████| 921/921 [00:00<00:00, 69443.87it/s]
Create valid_full_dataset: 100%|█████████████████████████████████| 517/517 [00:03<00:00, 140.77it/s]


Created dataset with 516 samples


valid_full_dataset build cluster: 100%|████████████████████████| 516/516 [00:00<00:00, 41897.57it/s]


iter: 000200 - train loss: 0.997 - lr: 9.93e-05 - load time: 7.66 - gpu time: 415.41
iter: 000400 - train loss: 1.148 - lr: 2.55e-04 - load time: 5.88 - gpu time: 420.43
iter: 000400 - valid loss: 1.123 - acc full seq: 0.1531 - acc per char: 0.6499
iter: 000600 - train loss: 1.317 - lr: 3.00e-04 - load time: 6.38 - gpu time: 402.53
iter: 000800 - train loss: 1.301 - lr: 2.98e-04 - load time: 7.89 - gpu time: 404.06
iter: 000800 - valid loss: 1.226 - acc full seq: 0.0717 - acc per char: 0.5082
iter: 001000 - train loss: 1.336 - lr: 2.93e-04 - load time: 7.32 - gpu time: 406.37
iter: 001200 - train loss: 1.297 - lr: 2.87e-04 - load time: 6.54 - gpu time: 407.34
iter: 001200 - valid loss: 1.102 - acc full seq: 0.1860 - acc per char: 0.6636
iter: 001400 - train loss: 1.281 - lr: 2.77e-04 - load time: 7.91 - gpu time: 409.95
iter: 001600 - train loss: 1.283 - lr: 2.66e-04 - load time: 7.17 - gpu time: 410.15
iter: 001600 - valid loss: 1.140 - acc full seq: 0.1453 - acc per char: 0.5773
iter