# 定义加载数据的工具

In [1]:
import torch

from data.ccpd2lpr import licenseDataset
from torch.utils.data import DataLoader
pathToCCPD=r'C:\CCPD\CCPD'
trainDataSet=licenseDataset(pathToCCPD,['train'],False)
valDataSet=licenseDataset(pathToCCPD,['val'],False)

trainDataLoader=DataLoader(dataset=trainDataSet,batch_size=32,shuffle=True,num_workers=16)
valDataLoader=DataLoader(dataset=valDataSet,batch_size=32,num_workers=16,drop_last=True)

# 加载LPRNet

In [2]:
from LPRNet.LPRNet import LPRNet
pretainModelPath=r"./weights/LPRNet/myLPRNet.pt"
lprnet=LPRNet(lpr_max_len=8,class_num=68,dropout_rate=0.5)
lprnet.load_state_dict(torch.load(pretainModelPath))

In [3]:
import torch.nn
import torch
from data.ccpd2lpr import CHARS
lr=0.001
optmizer=torch.optim.RMSprop(lprnet.parameters(),lr)
criterion=torch.nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')
device=torch.device(("cuda" if torch.cuda.is_available() else "cpu"))
total_epoch=100

In [4]:
# # 测试模型的运行并绘制模型结构
# import onnx
# import torch.onnx
# lprnet.to(device) # module类的to是in place操作
# inputs=torch.randn(1,3,24,94,dtype=torch.float32).to(device) # tensor类的to不是in place操作
# output=lprnet(inputs)
# output=output.transpose(0, 2)
# output=output.transpose(1, 2)
# print(output.shape)
# torch.onnx.export(lprnet, inputs, r"./models/showLPRNet.onnx", verbose=True, input_names=["input"], output_names=["output"])

In [None]:
# 判断是否是最优模型
def isbest(acc_list):
    lastest_acc=acc_list[-1]
    if lastest_acc==max(acc_list):
        return True
    else:
        return False

In [None]:
# 定义学习率调度器
def create_lr_scheduler(optimizer,num_step:int,epochs:int,warmup=True,warmup_epochs=10,warmup_factor=1e-3):
    assert num_step>0 and epochs>0
    if warmup is False:
        warmup_epochs=0
    def f(x):
        if warmup is True and x<=(warmup_epochs*num_step):
            alpha=float(x)/(warmup_epochs*num_step)
            return warmup_factor*(1-alpha)+alpha
        else:
            return (1-(x-warmup_epochs*num_step)/((epochs-warmup_epochs)*num_step))**0.9
    return torch.optim.lr_scheduler.LambdaLR(optimizer,f)
lr_scheduler=create_lr_scheduler(optmizer,len(trainDataLoader),total_epoch)

# 训练模型

In [5]:
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

lprnet.to(device)
running_losses=[]
running_acces=[]


def Greedy_Decode_Eval(Net):
    Net = Net.eval()
    Tp = 0
    Tn_1 = 0
    Tn_2 = 0
    for images, labels, lengths ,_ in valDataLoader:
        images=images.to(device)
        labels=labels.numpy()
        lengths=lengths.numpy()
        targets=[]
        for i in range(len(labels)):
            target=labels[i]
            target=target[:lengths[i]:]
            targets.append(target)
        # forward
        # images: [bs, 3, 24, 94]
        # prebs:  [bs, 68, 18]
        prebs = Net(images)
        # greedy decode
        prebs = prebs.cpu().detach().numpy()
        preb_labels = list()
        for i in range(prebs.shape[0]):
            preb = prebs[i, :, :]  # 对每张图片 [68, 18]
            preb_label = list()
            for j in range(preb.shape[1]):  # 18  返回序列中每个位置最大的概率对应的字符idx  其中'-'是67
                preb_label.append(np.argmax(preb[:, j], axis=0))

            no_repeat_blank_label = list()
            pre_c = preb_label[0]
            if pre_c != len(CHARS) - 1:  # 记录重复字符
                no_repeat_blank_label.append(pre_c)
            for c in preb_label:  # 去除重复字符和空白字符'-'
                if (pre_c == c) or (c == len(CHARS) - 1):
                    if c == len(CHARS) - 1:
                        pre_c = c
                    continue
                no_repeat_blank_label.append(c)
                pre_c = c
            preb_labels.append(no_repeat_blank_label)  # 得到最终的无重复字符和无空白字符的序列

        for i, label in enumerate(preb_labels):  # 统计准确率
            if len(label) != len(targets[i]):
                Tn_1 += 1  # 错误+1
                continue
            if (np.asarray(targets[i]) == np.asarray(label)).all():
                Tp += 1  # 完全正确+1
            else:
                Tn_2 += 1
    Acc = Tp * 1.0 / (Tp + Tn_1 + Tn_2)
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, (Tp+Tn_1+Tn_2)))
    return Acc

def trainLPRNet(lprnet):
    for epoch in tqdm(range(total_epoch), desc='Training model', ncols=100):
            lprnet.train()
            running_loss=0.0
            for images,labels,labelLengths,_ in trainDataLoader:
                images,labels,labelLengths=images.to(device),labels.to(device),labelLengths.to(device)
                optmizer.zero_grad()
                outputs=lprnet(images) # batchsize*68*18
                outputs=outputs.transpose(0, 2)
                outputs=outputs.transpose(1, 2) # 18,batchsize,68
                outputs=F.log_softmax(outputs,dim=2)
                output_lengths=torch.full((len(labelLengths),),18)
                loss=criterion(outputs,labels,output_lengths,labelLengths)
                running_loss+=loss
                loss.backward()
                optmizer.step()
                lr_scheduler.step()
            running_loss/=trainDataLoader.batch_size
            running_losses.append(running_loss)
            acc=Greedy_Decode_Eval(lprnet)
            print(f"epoch: {epoch} loss {running_loss}")
            running_acces.append(acc)
            if isbest(running_acces):
                torch.save(lprnet.state_dict(),'./mybestLPRNet.pt')

In [7]:
trainLPRNet(lprnet)

Training model:   2%|▉                                            | 1/50 [04:15<3:28:36, 255.43s/it]

[Info] Test Accuracy: 0.0 [0:100991:1:100992]
epoch: 0 loss 0.7365410923957825


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x000001C8BA6B3700>
Traceback (most recent call last):
  File "D:\miniconda\envs\cv_env\lib\site-packages\torch\utils\data\dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "D:\miniconda\envs\cv_env\lib\site-packages\torch\utils\data\dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "D:\miniconda\envs\cv_env\lib\multiprocessing\process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "D:\miniconda\envs\cv_env\lib\multiprocessing\popen_spawn_win32.py", line 108, in wait
    res = _winapi.WaitForSingleObject(int(self._handle), msecs)
KeyboardInterrupt: 
Training model:   2%|▉                                            | 1/50 [06:06<4:59:32, 366.79s/it]


KeyboardInterrupt: 

In [6]:
Greedy_Decode_Eval(lprnet)

[Info] Test Accuracy: 0.0 [0:100987:5:100992]


0.0

In [9]:
torch.save(lprnet.state_dict(),'./mylastLPRNet.pt')

In [None]:
lprnet=lprnet.load_state_dict(torch.load())