In [1]:
import os

# These environment variables only take effect if they are set before TensorFlow
# (and CUDA) are loaded -- including any TF import performed inside the project
# modules imported below -- so they must stay above those imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # hide TF C++ INFO and WARNING log lines
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # expose only GPU #1 to this process

# NOTE(review): these star imports supply names used later in the notebook
# (load_data, preprocess, LPRNet, CHARS); which module defines each is not
# visible here -- consider importing them explicitly.
from dataload import *
from model import *
import tensorflow as tf
import numpy as np
import time

# Silence Python-side TensorFlow logging below ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [2]:
# ---- Data locations ----
train_img_dir = './dataset/train'
test_img_dir = './dataset/eval'

# ---- Model / optimisation ----
epochs = 300
initial_lr = 0.001
dropout_rate = 0.5
lpr_len = 7  # fixed label length (symbols per plate) fed to the CTC loss

# ---- Batching (same size for training and evaluation) ----
train_batch_size = test_batch_size = 128

# ---- Checkpointing ----
saved_model_folder = './saved_reduced_model'
pretrained_model = ''  # empty string: no pretrained weights are loaded

In [3]:
def metric(test_dataset, lprnet):
    
    tp, tp_error = 0, 0  # tp为正确预测数量, tp_error是错误预测数量
    start_time = time.time()
    for cur_batch, (test_imgs, test_labels) in enumerate(test_dataset):
        test_labels = tf.cast(test_labels, tf.int32)
        prebs = lprnet(test_imgs)
        preb_labels = list()
        for i in range(prebs.shape[0]):
            preb = prebs[i, :, :]
            preb_label = list()
            for j in range(preb.shape[1]):
                preb_label.append(np.argmax(preb[:, j]))
            no_repeat_blank_label = list()
            pre_c = preb_label[0]
            if pre_c != len(CHARS) - 1:
                no_repeat_blank_label.append(pre_c)
            for c in preb_label:
                if (pre_c == c) or (c == len(CHARS) - 1):
                    if c == (len(CHARS) - 1):
                        pre_c = c
                    continue
                no_repeat_blank_label.append(c)
                pre_c = c
            preb_labels.append(no_repeat_blank_label)
        for i, label in enumerate(preb_labels):
            if len(label) != len(test_labels[i]):
                tp_error += 1
                continue
            if (np.asarray(test_labels[i]) == np.asarray(label)).all():
                tp += 1
            else:
                tp_error += 1
                
    end_time = time.time()
    t = end_time - start_time
    acc = tp * 1.0 / (tp + tp_error)  # 准确率
    
    return acc, tp, tp_error, t

In [4]:
# Load the training dataset.
# Pipeline order matters:
#  * cache() comes BEFORE shuffle(). tf.data's cache replays elements in exactly
#    the order they were first produced. With cache() after shuffle()/batch(),
#    the first epoch's order and batch composition were repeated in every later
#    epoch, so the data was never reshuffled.
#  * prefetch() goes last so it overlaps the whole input pipeline with training.
# NOTE(review): cache() after map() also freezes the output of `preprocess` (the
# original did this too); if preprocess applies random augmentation, move the
# augmentation after cache().
tr_imgs, tr_labels, train_imgs_num = load_data(img_dir=train_img_dir, lpr_len=lpr_len)
tr_imgs = tr_imgs.map(preprocess, num_parallel_calls=4)
train_dataset = (tf.data.Dataset.zip((tr_imgs, tr_labels))
                 .cache()
                 .shuffle(train_imgs_num)
                 .batch(train_batch_size)
                 .prefetch(tf.data.experimental.AUTOTUNE))

# Load the test dataset (fixed order; no shuffling needed for evaluation).
te_imgs, te_labels, test_imgs_num = load_data(img_dir=test_img_dir, lpr_len=lpr_len)
te_imgs = te_imgs.map(preprocess, num_parallel_calls=4)
test_dataset = (tf.data.Dataset.zip((te_imgs, te_labels))
                .cache()
                .batch(test_batch_size)
                .prefetch(tf.data.experimental.AUTOTUNE))

In [7]:
# Train the model: build LPRNet, minimise the CTC loss with Adam, evaluate on the
# test set after every epoch, keep the best-scoring checkpoint, then export it to
# TFLite.
os.makedirs(saved_model_folder, exist_ok=True)

# Instantiate the model.
lprnet = LPRNet(lpr_len=lpr_len, class_num=len(CHARS), dropout_rate=dropout_rate)
print("********** Successful to build network! **********\n")

# Optionally warm-start from pretrained weights.
if pretrained_model:
    lprnet.load_weights(pretrained_model)
    print("********** Successful to load pretrained model! **********")

# Adam optimiser with a constant learning rate.
optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr)

# The last class is used as the CTC blank.
blank_index = len(CHARS) - 1


def train_step(imgs, labels):
    """Run one optimisation step on a batch and return the mean CTC loss.

    ``training=True`` is passed explicitly. In a custom loop, Keras Dropout only
    drops units when called in training mode, so without this the configured
    ``dropout_rate`` would not be applied (unless LPRNet forces it internally,
    which is not visible here).

    Gradients are taken only w.r.t. trainable variables. Non-trainable ones
    would just receive None gradients.
    """
    labels = tf.cast(labels, tf.int32)  # [N, lpr_len]
    with tf.GradientTape() as tape:
        logits = lprnet(imgs, training=True)  # [N, num_classes, T], e.g. [N, 66, 18]
        logits = tf.transpose(logits, [2, 0, 1])  # time-major: [T, N, num_classes]
        num_steps, batch_size = logits.shape[0], logits.shape[1]
        per_sample_loss = tf.nn.ctc_loss(labels=labels,
                                         logits=logits,
                                         label_length=tf.fill([batch_size], lpr_len),
                                         logit_length=tf.fill([batch_size], num_steps),
                                         logits_time_major=True,
                                         blank_index=blank_index)
        loss = tf.reduce_mean(per_sample_loss)
    variables = lprnet.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss


# Training loop.
top_acc = 0.
for cur_epoch in range(1, epochs + 1):
    seen_samples = 0  # training images processed so far in this epoch
    for train_imgs, train_labels in train_dataset:
        start_time = time.time()
        loss = train_step(train_imgs, train_labels)
        end_time = time.time()
        seen_samples += int(np.shape(train_imgs)[0])
        # '\r' + end='' keeps the progress report on a single, overwritten line.
        print('\r' + "Epoch {0}/{1} || ".format(cur_epoch, epochs)
              + "Batch {0}/{1} || ".format(seen_samples, train_imgs_num)
              + "Loss:{} || ".format(loss)
              + "A Batch time:{0:.4f}s || ".format(end_time - start_time)
              + "Learning rate:{0:.8f} || ".format(float(optimizer.learning_rate.numpy())), end='')

    acc, tp, tp_error, t = metric(test_dataset, lprnet)
    total = tp + tp_error
    print("\n******* Prediction {0}/{1} || Acc:{2:.2f}% *******".format(tp, total, acc*100))
    # max(total, 1) avoids a ZeroDivisionError on an empty test set.
    print("******* Test speed: {}s 1/{} *******".format(t / max(total, 1), total))

    # Save the model whenever test accuracy matches or beats the best so far.
    if acc >= top_acc:
        top_acc = acc
        lprnet.save(saved_model_folder, save_format='tf')

# Convert the saved (.pb SavedModel) model to .tflite.
cvtmodel = tf.keras.models.load_model(saved_model_folder)
converter = tf.lite.TFLiteConverter.from_keras_model(cvtmodel)
tflite_model = converter.convert()
with open('lprnet' + '{}'.format(np.around(top_acc * 100)) + '.tflite', "wb") as f:
    f.write(tflite_model)
print("\n ********** Successful to convert tflite model! ********** \n")

********** Successful to build network! **********

Epoch 1/300 || Batch 208075/208075 || Loss:13.095730781555176 || A Batch time:0.7216s || Learning rate:0.00100000 || 
******* Prediction 221/6436 || Acc:3.43% *******
******* Test speed: 0.007432138349344599s 1/6436 *******
Epoch 2/300 || Batch 208075/208075 || Loss:5.929137706756592 || A Batch time:0.1720s || Learning rate:0.00100000 ||  
******* Prediction 2849/6436 || Acc:44.27% *******
******* Test speed: 0.007605572751473758s 1/6436 *******
Epoch 3/300 || Batch 208075/208075 || Loss:3.994891881942749 || A Batch time:0.1564s || Learning rate:0.00100000 ||  
******* Prediction 3876/6436 || Acc:60.22% *******
******* Test speed: 0.007074955496601321s 1/6436 *******
Epoch 4/300 || Batch 208075/208075 || Loss:2.946272850036621 || A Batch time:0.1367s || Learning rate:0.00100000 ||  
******* Prediction 4183/6436 || Acc:64.99% *******
******* Test speed: 0.007110061039459342s 1/6436 *******
Epoch 5/300 || Batch 208075/208075 || Loss:2.6

Epoch 37/300 || Batch 208075/208075 || Loss:0.8852901458740234 || A Batch time:0.1684s || Learning rate:0.00100000 ||  
******* Prediction 5346/6436 || Acc:83.06% *******
******* Test speed: 0.00709515591450105s 1/6436 *******
Epoch 38/300 || Batch 208075/208075 || Loss:0.8256071209907532 || A Batch time:0.1423s || Learning rate:0.00100000 ||  
******* Prediction 5372/6436 || Acc:83.47% *******
******* Test speed: 0.007141555185418754s 1/6436 *******
Epoch 39/300 || Batch 208075/208075 || Loss:0.7761783599853516 || A Batch time:0.1570s || Learning rate:0.00100000 ||  
******* Prediction 5360/6436 || Acc:83.28% *******
******* Test speed: 0.007014059405507621s 1/6436 *******
Epoch 40/300 || Batch 208075/208075 || Loss:0.7253499031066895 || A Batch time:0.1444s || Learning rate:0.00100000 ||  
******* Prediction 5364/6436 || Acc:83.34% *******
******* Test speed: 0.007159635545008667s 1/6436 *******
Epoch 41/300 || Batch 208075/208075 || Loss:0.7408726811408997 || A Batch time:0.1439s ||

Epoch 73/300 || Batch 208075/208075 || Loss:0.4741989076137543 || A Batch time:0.1425s || Learning rate:0.00100000 ||  
******* Prediction 5416/6436 || Acc:84.15% *******
******* Test speed: 0.007138008097819915s 1/6436 *******
Epoch 74/300 || Batch 208075/208075 || Loss:0.513117253780365 || A Batch time:0.1388s || Learning rate:0.00100000 ||   
******* Prediction 5412/6436 || Acc:84.09% *******
******* Test speed: 0.007110247373432754s 1/6436 *******
Epoch 75/300 || Batch 208075/208075 || Loss:0.4764088988304138 || A Batch time:0.1528s || Learning rate:0.00100000 ||  
******* Prediction 5405/6436 || Acc:83.98% *******
******* Test speed: 0.007138868716283197s 1/6436 *******
Epoch 76/300 || Batch 208075/208075 || Loss:0.4878706932067871 || A Batch time:0.1535s || Learning rate:0.00100000 ||  
******* Prediction 5400/6436 || Acc:83.90% *******
******* Test speed: 0.007010970077135312s 1/6436 *******
Epoch 77/300 || Batch 208075/208075 || Loss:0.4463517665863037 || A Batch time:0.1712s |

Epoch 109/300 || Batch 208075/208075 || Loss:0.43089157342910767 || A Batch time:0.1374s || Learning rate:0.00100000 || 
******* Prediction 5364/6436 || Acc:83.34% *******
******* Test speed: 0.007138236662554696s 1/6436 *******
Epoch 110/300 || Batch 208075/208075 || Loss:0.4038713872432709 || A Batch time:0.1383s || Learning rate:0.00100000 ||  
******* Prediction 5370/6436 || Acc:83.44% *******
******* Test speed: 0.007099379620394994s 1/6436 *******
Epoch 111/300 || Batch 208075/208075 || Loss:0.40903979539871216 || A Batch time:0.1420s || Learning rate:0.00100000 || 
******* Prediction 5385/6436 || Acc:83.67% *******
******* Test speed: 0.007065709182532994s 1/6436 *******
Epoch 112/300 || Batch 208075/208075 || Loss:0.39524638652801514 || A Batch time:0.1432s || Learning rate:0.00100000 || 
******* Prediction 5380/6436 || Acc:83.59% *******
******* Test speed: 0.007073337206416726s 1/6436 *******
Epoch 113/300 || Batch 208075/208075 || Loss:0.38290753960609436 || A Batch time:0.1

Epoch 145/300 || Batch 208075/208075 || Loss:0.3741193115711212 || A Batch time:0.1422s || Learning rate:0.00100000 ||  
******* Prediction 5362/6436 || Acc:83.31% *******
******* Test speed: 0.007143502134651483s 1/6436 *******
Epoch 146/300 || Batch 208075/208075 || Loss:0.38366439938545227 || A Batch time:0.1371s || Learning rate:0.00100000 || 
******* Prediction 5355/6436 || Acc:83.20% *******
******* Test speed: 0.007113206416201784s 1/6436 *******
Epoch 147/300 || Batch 208075/208075 || Loss:0.41459929943084717 || A Batch time:0.1550s || Learning rate:0.00100000 || 
******* Prediction 5352/6436 || Acc:83.16% *******
******* Test speed: 0.007157527489125765s 1/6436 *******
Epoch 148/300 || Batch 208075/208075 || Loss:0.37706559896469116 || A Batch time:0.1434s || Learning rate:0.00100000 || 
******* Prediction 5362/6436 || Acc:83.31% *******
******* Test speed: 0.007167167216076326s 1/6436 *******
Epoch 149/300 || Batch 208075/208075 || Loss:0.3695860505104065 || A Batch time:0.13

KeyboardInterrupt: 