In [1]:
import os
import torch.backends.cudnn as cudnn
import yaml
from train import train
from utils import AttrDict, CTCLabelConverter
import pandas as pd

In [2]:
def get_config(file_path):
    """Load a training config from YAML and derive the model's character set.

    Args:
        file_path: Path to a YAML config file (read as UTF-8).

    Returns:
        The parsed config wrapped in an ``AttrDict``, with ``opt.character`` added.

    How ``opt.character`` is built:
        * If ``lang_char`` is the string ``'None'`` or a YAML null, it is the sorted
          set of every character in the ``words`` column of
          ``<train_data>/<name>/labels.csv``, for each ``name`` in the
          '-'-separated ``select_data`` string.
        * Otherwise it is ``number + symbol + lang_char``.

    Side effect:
        Creates ``./saved_models/<experiment_name>/`` if it does not exist.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    # YAML parses the bare word `None` as the string 'None', but `null` / `~` / an
    # empty value parse to Python None. Treat both as "derive from the labels";
    # otherwise a real null would crash the concatenation in the else-branch.
    if opt.lang_char is None or opt.lang_char == 'None':
        characters = set()
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # The regex separator splits each row only at its first comma, so commas
            # inside the label text survive. dtype=str stops pandas from turning
            # all-numeric label columns into ints, which would break iteration and
            # drop leading zeros.
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',
                             usecols=['filename', 'words'], dtype={'words': str},
                             keep_default_na=False)
            for word in df['words']:
                characters.update(word)
        opt.character = ''.join(sorted(characters))
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt

In [3]:
# Load the training configuration. Note: get_config also creates
# ./saved_models/<experiment_name>/ as a side effect.
opt = get_config(file_path="config_files/en_filtered_config.yaml")

In [4]:
# Build the CTC label converter from the configured character set.
converter = CTCLabelConverter(opt.character)
# The class count is taken from the converter's character list rather than from
# opt.character. Presumably the converter adds special tokens such as the CTC
# blank (TODO confirm in utils.CTCLabelConverter).
opt.num_class = len(converter.character)

In [5]:
from model import Model
# Build the float recognition model. This must run after opt.num_class is set
# above; presumably Model reads it to size its output layer (TODO confirm).
model = Model(opt=opt)

In [7]:
import torch
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model prefix every
    parameter name with ``'module.'``. Keys without the prefix are kept unchanged,
    so checkpoints saved from an unwrapped model load correctly too.
    (The previous code sliced ``key[7:]`` unconditionally, which corrupted such keys.)
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Newer PyTorch versions accept `weights_only=True` for plain state dicts.
pretrained_dict = torch.load(opt.saved_model, map_location='cpu')
new_state_dict = strip_module_prefix(pretrained_dict)
# Kept as the last expression so the cell displays the key-matching result.
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [8]:
from modules.quantization import QuantizationOps

# Only the CNN backbone (model.FeatureExtraction) is handed to QuantizationOps.
# The other submodules are not passed in. The QConfig text in this cell's output
# is presumably printed by QuantizationOps (TODO confirm).
qat_ops = QuantizationOps(model=model.FeatureExtraction)
model.FeatureExtraction = qat_ops.quantized_model
# Relies on `torch` imported in an earlier cell. The DataParallel wrapper is kept
# even on CPU, so the model is reached via `model.module` as later cells do.
# Presumably train() expects this layout too (TODO confirm).
model = torch.nn.DataParallel(model).to('cpu')

QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))




In [13]:
model.module.FeatureExtraction

QuantizationVGG(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (dequant): DeQuantStub()
  (model_fp32): VGG_FeatureExtractor(
    (ConvNet): Sequential(
      (0): ConvReLU2d(
        1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (weight_fake_quant): PerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
        (activation_post_process): HistogramObserver()
      )
      (1): Identity()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): ConvReLU2d(
        32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (weight_fake_quant): PerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
        (activation_post_process): HistogramObserver()
      )
      (4): Identity()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): ConvReLU2d(
        64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (weight_f