In [1]:
import os
import torch.backends.cudnn as cudnn
import yaml
from train import train
from utils import AttrDict, CTCLabelConverter
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Let cuDNN benchmark candidate convolution algorithms and pick the fastest one.
cudnn.benchmark = True
# Do not restrict cuDNN to deterministic algorithms (run-to-run results may differ).
cudnn.deterministic = False

In [3]:
def get_config(file_path):
    """Read a YAML option file and attach the recognition character set to it.

    Args:
        file_path: Path of the YAML configuration file (read as UTF-8).

    Returns:
        An ``AttrDict`` holding the parsed options plus a ``character``
        attribute:

        * when ``lang_char`` is the string ``'None'``, ``character`` is the
          sorted set of every character found in the ``words`` column of
          ``<train_data>/<dataset>/labels.csv`` for each dataset named in the
          hyphen-separated ``select_data`` option;
        * otherwise it is ``number + symbol + lang_char``.

    Side effects:
        Creates ``./saved_models/<experiment_name>`` if it does not exist.
    """
    with open(file_path, 'r', encoding="utf8") as config_file:
        raw_options = yaml.safe_load(config_file)
    opt = AttrDict(raw_options)

    if opt.lang_char != 'None':
        # Character set is spelled out explicitly in the config.
        opt.character = opt.number + opt.symbol + opt.lang_char
    else:
        # Derive the character set from the labels of every selected dataset.
        seen_chars = set()
        for dataset_name in opt['select_data'].split('-'):
            labels_path = os.path.join(opt['train_data'], dataset_name, 'labels.csv')
            # The anchored regex separator consumes only the first comma of a
            # row, so any later commas remain part of the 'words' field.
            labels = pd.read_csv(
                labels_path,
                sep='^([^,]+),',
                engine='python',
                usecols=['filename', 'words'],
                keep_default_na=False,
            )
            seen_chars.update(''.join(labels['words']))
        opt.character = ''.join(sorted(seen_chars))

    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt

In [4]:
# Load the options from the filtered-English config file via get_config.
opt = get_config("config_files/en_filtered_config.yaml")

In [5]:
# Build the CTC label converter from the configured character set.
converter = CTCLabelConverter(opt.character)
# The class count comes from the converter's own character list, not from
# opt.character directly (the two may differ in length — TODO confirm in
# utils.CTCLabelConverter).
opt.num_class = len(converter.character)

In [6]:
# Model comes from the project-local `model` module and is configured
# entirely through the option object built above.
from model import Model
model = Model(opt=opt)

In [7]:
import torch
from collections import OrderedDict

weights_path = './weights/english.pth'
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (on recent PyTorch
# versions consider passing weights_only=True).
state_dict = torch.load(weights_path, map_location='cpu')

# The checkpoint keys carry a 7-character 'module.' wrapper prefix
# (presumably from saving an nn.DataParallel-wrapped model) that the bare
# Model does not use. Strip it only when it is actually present, so a
# checkpoint saved without the wrapper is not corrupted by blind slicing.
wrapper_prefix = 'module.'
new_state_dict = OrderedDict()

for key, value in state_dict.items():
    if key.startswith(wrapper_prefix):
        new_key = key[len(wrapper_prefix):]
    else:
        new_key = key
    new_state_dict[new_key] = value

In [8]:
# Copy the re-keyed checkpoint weights into the model; the returned key-match
# report is displayed as the cell output.
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [9]:
# Switch only the feature extractor to eval mode; as the last expression of
# the cell, its repr (the layer listing before fusion) is displayed.
model.FeatureExtraction.eval()

VGG_FeatureExtractor(
  (ConvNet): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [10]:
import torch.nn as nn

# Each inner list names the child modules (by their index inside the ConvNet
# Sequential) that are fused together as one unit in the next cell.
layer_fused = [
    ['0', '1'],
    ['3', '4'],
    ['6', '7'],
    ['8', '9'],
    ['18', '19'],
    ['11', '12', '13'],
    ['14', '15', '16'],
]

# Display the first group as a quick sanity check.
layer_fused[0]

['0', '1']

In [11]:
# Fusion below is performed with the whole model switched to training mode.
model = model.train()

for submodule in model.FeatureExtraction.modules():
    # Exact type comparison: subclasses of nn.Sequential do not match and are
    # therefore left untouched.
    if type(submodule) is not nn.Sequential:
        continue
    # Fuse every configured group of child layers in place.
    for fuse_group in layer_fused:
        torch.quantization.fuse_modules(submodule, fuse_group, inplace=True)

In [12]:
# Put the (now fused) feature extractor back into eval mode; the displayed
# repr shows the layer listing after fusion.
model.FeatureExtraction.eval()

VGG_FeatureExtractor(
  (ConvNet): Sequential(
    (0): ConvReLU2d(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (1): Identity()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ConvReLU2d(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (4): Identity()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): ConvReLU2d(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (7): Identity()
    (8): ConvReLU2d(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (9): Identity()
    (10): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (11): ConvBnReLU2d(
      (0): Conv2d(128, 256, kernel_size=(3, 3), 