In [1]:
import math, os
import numpy as np
import h5py
import matplotlib.pyplot as plt
import scipy
from PIL import Image
from scipy import ndimage
import torch
import torch.nn as nn
from cnn_utils import *
from torch import nn,optim
from torch.utils.data import DataLoader,Dataset, WeightedRandomSampler
from torchvision import transforms
from ClassicNetwork.ResNet import ResNet50
from torchmetrics import ConfusionMatrix

PyTorch Version:  1.12.1
Torchvision Version:  0.13.1


In [2]:
np.random.seed(1)
torch.manual_seed(1)
batch_size = 144
learning_rate = 0.009
num_epocher = 30
pre_epoch = 0

In [3]:
rmb_label = dict()
class_index = 0
for i in range(10):
    rmb_label[str(i)] = class_index
    class_index += 1

for i in range (26):
    rmb_label[chr(65 + i)] = class_index
    class_index += 1
print(rmb_label)

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, 'A': 10, 'B': 11, 'C': 12, 'D': 13, 'E': 14, 'F': 15, 'G': 16, 'H': 17, 'I': 18, 'J': 19, 'K': 20, 'L': 21, 'M': 22, 'N': 23, 'O': 24, 'P': 25, 'Q': 26, 'R': 27, 'S': 28, 'T': 29, 'U': 30, 'V': 31, 'W': 32, 'X': 33, 'Y': 34, 'Z': 35}


In [4]:
class MyData(Dataset):
    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        counts = [0] * 36
        # data_dir 是训练集、验证集或者测试集的路径
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            # dirs ['1', '100']
            for sub_dir in dirs:
                # 文件列表
                img_names = os.listdir(os.path.join(root, sub_dir))
                # 取出 jpg 结尾的文件
                img_names = list(filter(lambda x: x.endswith('.png'), img_names))
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    # 图片的绝对路径
                    path_img = os.path.join(root, sub_dir, img_name)
                    # 标签，这里需要映射为 0、1 两个类别
                    #print(path_img)
                    label = rmb_label[sub_dir]
                    # 保存在 data_info 变量中
                    data_info.append((path_img, int(label)))
                    counts[int(label)] = counts[int(label)] + 1
        return data_info, counts

    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform，数据预处理
        """
        # data_info存储所有图片路径和标签，在DataLoader中通过index读取样本
        self.data_info, self.counts = self.get_img_info(data_dir)
        self.transform = transform
    
    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, index):
        # 通过 index 读取样本
        path_img, label = self.data_info[index]
        # 注意这里需要 convert('RGB')
        img = Image.open(path_img).convert('L')     # 0~255
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform，转为tensor等等
        # 返回是样本和标签
        return img, label

In [5]:
train_dir = './OCR_Image1/Train_Set'
valid_dir = './OCR_Image1/Valid_Set'
train_data = MyData(data_dir=train_dir, transform=transforms.Compose(
    [transforms.Resize([48, 36]), transforms.ToTensor()]))
valid_data = MyData(data_dir=valid_dir, transform=transforms.Compose(
    [transforms.Resize([48, 36]), transforms.ToTensor()]))

In [6]:
# 构建DataLoder
# 其中训练集设置 shuffle=True，表示每个 Epoch 都打乱样本
print(f"batch_size: {batch_size}")
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
print(len(train_loader))
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)

batch_size: 144
21


In [7]:
def updateClassCorrectCount(inputLabel, predLabel, labelTotoal, labelCorrect):
    index = 0
    for v in inputLabel:
        label = v.item()
        if label in labelTotoal:
            labelTotoal[label] += 1
        else:
            labelTotoal[label] = 1

        if (label == predLabel[index].item()):
            if label in labelCorrect:
                labelCorrect[label] += 1
            else:
                labelCorrect[label] = 1
        index += 1

def printLabelAcc(labelTotoal, labelCorrect):
    for label in labelTotoal:
        correctCount = 0
        if label in labelCorrect:
            correctCount = labelCorrect[label]
        acc = correctCount / labelTotoal[label]
        print(f"label {label}, acc {acc}")

In [8]:
device = 'cpu'

def test():
    model.eval()    #需要说明是否模型测试
    eval_loss = 0
    eval_acc = 0
    labelTotoalG = dict()
    labelCorrectG = dict()
    target = torch.tensor([], dtype=torch.int32)
    preds = torch.tensor([], dtype=torch.int32)
    for data in valid_loader:
        img, label = data
        img = img.float().to(device)
        label = label.long().to(device)
        out = model(img)    #前向算法
        loss = criterion(out,label) #计算loss
        eval_loss += loss.item() * label.size(0)    #total loss
        _,pred = torch.max(out,1)   #预测结果
        num_correct = (pred == label).sum() #正确结果
        
        target = torch.cat((target, label), -1)
        preds = torch.cat((preds, pred), -1)

        updateClassCorrectCount(label, pred, labelTotoalG, labelCorrectG)
        eval_acc += num_correct.item()  #正确结果总数

    print('Test Loss:{:.6f}, Acc: {:.6f}'
          .format(eval_loss/ (len(valid_data)), eval_acc * 1.0/(len(valid_data))))
    printLabelAcc(labelTotoalG, labelCorrectG)

    confmat = ConfusionMatrix(num_classes=36)
    confu = confmat(preds, target)

    torch.set_printoptions(threshold=100_100)
    print("confusion matrix:")
    print(confu)

    acc = eval_acc * 1.0/(len(valid_data))
    return acc

In [9]:
##### import model
# model = ResModel(6)
model = ResNet50(num_classes=len(rmb_label), imgsz = 64)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.8)

In [10]:
last = 0.90
for epoch in range(pre_epoch, num_epocher):
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    for i, data in enumerate(train_loader, 1):
        img, label = data
        print(f"data length {len(label)}")
        #print(img)
        #print(label)
        img = img.float().to(device)
        label = label.long().to(device)
        #前向传播
        out = model(img)
        loss = criterion(out,label) #loss
        running_loss += loss.item() * label.size(0)
        _,pred = torch.max(out,1)   #预测结果
        num_correct = (pred == label).sum() #正确结果的数量
        running_acc += num_correct.item()   #正确结果的总数
        
        optimizer.zero_grad()   #梯度清零
        loss.backward() #后向传播计算梯度
        optimizer.step()    #利用梯度更新W，b参数
    #打印一个循环后，训练集合上的loss和正确率
    if (epoch+1) % 1 == 0:
        print('Train{} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(epoch+1, running_loss / (len(train_data)),
                                                               running_acc / (len(train_data))))
        now = test()
        print(f"Now: {now}, last: {last}")

    ## save model
    # if epoch == 0:
    #     state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
    #     torch.save(state, './Resnet'+str(int(now*10000)/100)+'epoch'+str(epoch)+'.pt')
    #     last = now
    
    if now > last:
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, './Results/OCR_Gray_Resnet_' + str(int(now*10000)/100) + '_epoch_'+ str(epoch) + '_v7.pt')
        last = now

        example = torch.rand(1, 1, 36, 48)
        traced_script_module = torch.jit.trace(model, example)
        traced_script_module.save("./Results/OCR_Gray_traced_resnet_model_v7.pt")

        if last >= 0.93:
            print("complete tranining")
            break


data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 144
data length 17
Train1 epoch, Loss: 4.497461, Acc: 0.115982
Test Loss:13.104855, Acc: 0.081081
label 27, acc 0.0
label 30, acc 0.0
label 9, acc 0.0
label 0, acc 0.013986013986013986
label 7, acc 0.0
label 18, acc 0.0
label 23, acc 0.0
label 16, acc 0.0
label 6, acc 0.18604651162790697
label 35, acc 0.0
label 1, acc 0.0
label 8, acc 0.0
label 29, acc 0.0
label 28, acc 0.0
label 10, acc 0.0
label 15, acc 0.0
label 24, acc 0.0
label 17, acc 0.0
label 22, acc 0.0
label 19, acc 0.0
label 12, acc 0.0
label 13, acc 0.0
label 31, acc 0.0
label 26, acc 0.0
label 4, acc 0.0
label 33, acc 0.0
label 3, acc 0.0
label 14, acc 0.0
label 11, acc 0.0
label 20, acc 0.0
label 21, acc 0.0
lab