In [1]:
# With the experience from the previous two models, the code in this part is straightforward.
# First load the processors, which put the data into the format the models expect.
import torch
import random
from transformers import AutoImageProcessor, TrOCRProcessor

# Processor for the text-location model (detr-resnet-50), loaded from a local directory.
# `size` is overridden so that both the longest and the shortest edge are set to 800.
processor_location = AutoImageProcessor.from_pretrained(
    'processor/facebook/detr-resnet-50',
    size={
        'longest_edge': 800,
        'shortest_edge': 800
    })

# Processor for the text-recognition model (trocr-base-handwritten), loaded from a local directory.
processor_recognition = TrOCRProcessor.from_pretrained(
    'processor/microsoft/trocr-base-handwritten')

# Display both processors as the cell output.
processor_location, processor_recognition

(DetrImageProcessor {
   "do_convert_annotations": true,
   "do_normalize": true,
   "do_pad": true,
   "do_rescale": true,
   "do_resize": true,
   "format": "coco_detection",
   "image_mean": [
     0.485,
     0.456,
     0.406
   ],
   "image_processor_type": "DetrImageProcessor",
   "image_std": [
     0.229,
     0.224,
     0.225
   ],
   "resample": 2,
   "rescale_factor": 0.00392156862745098,
   "size": {
     "longest_edge": 800,
     "shortest_edge": 800
   }
 },
 TrOCRProcessor:
 - image_processor: ViTImageProcessor {
   "do_normalize": true,
   "do_rescale": true,
   "do_resize": true,
   "image_mean": [
     0.5,
     0.5,
     0.5
   ],
   "image_processor_type": "ViTImageProcessor",
   "image_std": [
     0.5,
     0.5,
     0.5
   ],
   "processor_class": "TrOCRProcessor",
   "resample": 2,
   "rescale_factor": 0.00392156862745098,
   "size": {
     "height": 384,
     "width": 384
   }
 }
 
 - tokenizer: RobertaTokenizerFast(name_or_path='processor/microsoft/trocr-bas

In [2]:
# Class definition used by the previously trained text-location model.
class Model(torch.nn.Module):
    """Text-location wrapper: a backbone followed by a class head and a box head.

    ``__init__`` registers no sub-modules. ``forward`` reads ``self.model``,
    ``self.class_labels_classifier`` and ``self.bbox_predictor``, so these must
    already exist on the instance (presumably restored from the pickled model
    object that is loaded with ``torch.load`` -- confirm against the training
    notebook).
    """

    def __init__(self):
        super().__init__()

    def forward(self, pixel_values, pixel_mask):
        """Run the backbone and both prediction heads.

        Args:
            pixel_values: image batch, passed to ``self.model`` as ``pixel_values``.
            pixel_mask: mask batch, passed to ``self.model`` as ``pixel_mask``.

        Returns:
            Tuple ``(class_pred, box_pred)``: the raw output of the class head
            and the sigmoid of the box head output (values in (0, 1)), both
            computed from the backbone's ``last_hidden_state``.
        """
        backbone_out = self.model(pixel_values=pixel_values,
                                  pixel_mask=pixel_mask)
        hidden = backbone_out.last_hidden_state

        class_pred = self.class_labels_classifier(hidden)
        # Squash the box head output into (0, 1).
        box_pred = torch.sigmoid(self.bbox_predictor(hidden))

        return class_pred, box_pred


# Load the previously trained text-location model object from disk and move it to the GPU.
# NOTE(review): torch.load unpickles Python objects -- only load model files you trust.
# NOTE(review): recent PyTorch releases default to weights_only=True, which rejects
# fully pickled modules like this one -- verify against the installed version.
model_location = torch.load('model/文字定位.model').to('cuda')

In [3]:
# Class definition used by the previously trained text-recognition model.
class Model(torch.nn.Module):
    """Text-recognition wrapper that greedily decodes token ids from images.

    ``__init__`` registers no sub-modules. ``forward`` reads ``self.encoder``
    and ``self.decoder``, so these must already exist on the instance
    (presumably restored from the pickled model object that is loaded with
    ``torch.load`` -- confirm against the training notebook).

    ``forward`` also reads the module-level ``processor_recognition`` to obtain
    the start (cls) token id.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pixel_values, max_new_tokens=127):
        """Greedily generate one token sequence per input image.

        Args:
            pixel_values: image batch fed to ``self.encoder``; its first
                dimension is the batch size.
            max_new_tokens: number of decoding steps. Defaults to 127, giving
                sequences of 128 tokens including the start token.

        Returns:
            int64 tensor of shape ``(batch, 1 + max_new_tokens)`` on the same
            device as ``pixel_values``. Decoding always runs the full number of
            steps; it does not stop at an end token, so callers truncate.
        """
        last_hidden_state = self.encoder(pixel_values).last_hidden_state

        # Create the start tokens directly on the input's device instead of a
        # hard-coded 'cuda', so the model also runs on CPU or another GPU.
        input_ids = torch.full((len(pixel_values), 1),
                               processor_recognition.tokenizer.cls_token_id,
                               dtype=torch.long,
                               device=pixel_values.device)

        for _ in range(max_new_tokens):
            logits = self.decoder(input_ids=input_ids,
                                  encoder_hidden_states=last_hidden_state).logits
            # Greedy choice: most likely token at the last position only.
            next_token = logits.argmax(2)[:, -1].unsqueeze(1)
            input_ids = torch.cat([input_ids, next_token], 1)

        return input_ids
        
# Load the previously trained text-recognition model object from disk and move it to the GPU.
# NOTE(review): torch.load unpickles Python objects -- only load model files you trust.
# NOTE(review): recent PyTorch releases default to weights_only=True, which rejects
# fully pickled modules like this one -- verify against the installed version.
model_recognition = torch.load('model/文字识别.model').to('cuda')

In [4]:
# Prepare the data.
from datasets import load_from_disk

# Read the saved dataset from disk and take its 'train' split.
dataset = load_from_disk('dataset/data')['train']
# Keep only the first 500 rows and drop the 'ocr' column (only the images are used below).
dataset = dataset.select(range(500)).remove_columns(['ocr'])

def f(data):
    """Resize the row's image to 800x800.

    Args:
        data: one dataset row; ``data['image']`` must provide ``resize``.

    Returns:
        The same ``data`` object, with ``data['image']`` replaced by the
        resized image.
    """
    resized = data['image'].resize([800, 800])
    data['image'] = resized
    return data

# Apply the 800x800 resize to every row of the dataset.
dataset = dataset.map(f)

# Display the dataset summary and its first row as the cell output.
dataset, dataset[0]

(Dataset({
     features: ['image'],
     num_rows: 500
 }),
 {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=800x800>})

In [5]:
# Execute the text-location loss-function notebook from this notebook.
# NOTE(review): presumably this is what defines `matcher`, which location() uses later -- confirm.
%run 2.文字定位Loss函数.ipynb

In [6]:
# Text-location function: finds the boxes that are later cropped for recognition.
def location(image):
    """Predict text boxes and their class ids for one image.

    Relies on the module-level ``processor_location``, ``model_location`` and
    ``matcher``.

    Args:
        image: the image to analyse, passed straight to ``processor_location``.

    Returns:
        Tuple ``(class_pred, box_pred)`` of plain Python lists: up to 16 class
        ids (each <= 7) and the matching boxes after conversion by
        ``matcher.xywh_to_x1y1x2y2`` and scaling by 800.
    """
    # Preprocess with the text-location processor and move the batch to the GPU.
    inputs = processor_location(image, return_tensors='pt').to('cuda')

    # Inference only: run the text-location model without tracking gradients.
    with torch.no_grad():
        class_scores, rel_boxes = model_location(inputs['pixel_values'],
                                                 inputs['pixel_mask'])

    # Most likely class per prediction; keep class ids 0-7 and at most 16 hits.
    # NOTE(review): ids above 7 are presumably the "no object" class -- confirm.
    class_ids = class_scores.argmax(dim=2)
    keep = class_ids <= 7
    rel_boxes = rel_boxes[keep][:16]
    class_ids = class_ids[keep][:16]

    # matcher.xywh_to_x1y1x2y2 converts (x, y, w, h) boxes into corner
    # coordinates. All values are relative (fractions of the image size).
    # NOTE(review): the original comment describes x, y as the top-left corner;
    # verify against the matcher implementation whether they are the corner or
    # the box centre.
    corner_boxes = matcher.xywh_to_x1y1x2y2(rel_boxes)
    # Relative -> absolute pixel coordinates (hard-coded for 800x800 images).
    pixel_boxes = corner_boxes * 800

    return class_ids.tolist(), pixel_boxes.tolist()

# Pick a random sample image from the dataset.
# NOTE(review): no random seed is set in the visible cells, so the sample differs between runs.
image = random.choice(dataset)['image']

# Run text location on the sample.
class_pred, box_pred = location(image)
    
# Display the predicted class ids and pixel boxes as the cell output.
class_pred, box_pred

([6, 3, 2, 0, 1, 4, 7, 5, 2],
 [[68.11900329589844, 349.3677978515625, 520.3612060546875, 508.5596923828125],
  [63.87397003173828, 278.0183410644531, 187.609130859375, 339.68572998046875],
  [291.8573303222656, 167.39292907714844, 359.6639099121094, 244.5283203125],
  [48.233367919921875,
   82.97744750976562,
   223.44288635253906,
   161.0249786376953],
  [54.67331314086914,
   181.38064575195312,
   122.99226379394531,
   257.2747802734375],
  [240.29493713378906,
   269.9354248046875,
   311.1112976074219,
   329.4949035644531],
  [226.87704467773438, 571.681884765625, 715.569091796875, 669.05029296875],
  [348.2613220214844,
   264.5687255859375,
   416.0237731933594,
   322.81536865234375],
  [291.8280029296875,
   167.22036743164062,
   360.14593505859375,
   247.28839111328125]])

In [7]:
import PIL.Image


def recognition(image, box_pred):
    """Recognise the text inside each predicted box of ``image``.

    Relies on the module-level ``processor_recognition`` and
    ``model_recognition``.

    Args:
        image: PIL image the boxes refer to.
        box_pred: sequence of ``[x1, y1, x2, y2]`` pixel boxes, as returned by
            ``location``.

    Returns:
        List of decoded strings, one per box and in the same order. Returns an
        empty list when ``box_pred`` is empty.
    """
    # Nothing to recognise: avoid sending an empty batch to the processor/model.
    if len(box_pred) == 0:
        return []

    def pad(image):
        """Fit a crop into the top-left corner of a black 384x384 canvas.

        The crop is scaled so that its longer side becomes 384 pixels.
        """
        canvas = PIL.Image.new('RGB', [384, 384], 'black')

        w, h = image.size
        # Degenerate (empty) crop: there is nothing to paste, and the ratio
        # below would divide by zero or ask for a zero-sized resize.
        if w <= 0 or h <= 0:
            return canvas

        ratio = 384 / max(w, h)

        # Clamp to at least one pixel: very thin crops would otherwise round
        # down to 0, which resize rejects.
        w = max(1, int(ratio * w))
        h = max(1, int(ratio * h))

        canvas.paste(image.resize([w, h]), [0, 0])

        return canvas

    def decode(input_ids):
        """Turn one row of generated token ids into text."""
        input_ids = input_ids.tolist()

        # The model always generates the full length, so cut the sequence
        # right after the first sep token if there is one.
        if processor_recognition.tokenizer.sep_token_id in input_ids:
            idx = input_ids.index(processor_recognition.tokenizer.sep_token_id) + 1
            input_ids = input_ids[:idx]

        return processor_recognition.tokenizer.decode(input_ids, skip_special_tokens=True)

    crops = [pad(image.crop(box)) for box in box_pred]
    pixel_values = processor_recognition(crops, return_tensors='pt').pixel_values.to('cuda')

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        token_ids = model_recognition(pixel_values)

    return [decode(ids) for ids in token_ids]


# Recognise the text in every box predicted for the current sample image.
text = recognition(image, box_pred)

# Display the recognised strings (one per box, in the same order as box_pred).
text

['澳门特别行政区澳门特别行政区', '2003', '汉', '', '男', '10', '', '27', '汉']

In [8]:
# Just a deliberate cut-off (the division raises ZeroDivisionError); it can be ignored.
1/0

ZeroDivisionError: division by zero

In [None]:
import PIL.ImageDraw
import PIL.ImageFont
from matplotlib import pyplot as plt
import numpy as np


def show(image, box_pred, class_pred):
    """Plot ``image`` with the predicted boxes and class ids drawn in red.

    The drawing is done on a copy, so the caller's image stays unmodified and
    can still be cropped for recognition afterwards.

    Args:
        image: PIL image to display.
        box_pred: sequence of ``[x1, y1, x2, y2]`` pixel boxes.
        class_pred: sequence of class ids, aligned with ``box_pred``.
    """
    try:
        font = PIL.ImageFont.truetype('arial.ttf', size=50)
    except OSError:
        # arial.ttf is not installed on every system; fall back to the font
        # bundled with Pillow instead of failing.
        font = PIL.ImageFont.load_default()

    # Draw on a copy so the rectangles and labels never leak into the original.
    canvas = image.copy()
    draw = PIL.ImageDraw.Draw(canvas)
    for b, c in zip(box_pred, class_pred):
        draw.rectangle(b, outline='red', width=5)
        # Label the box with its class id at the box's first corner.
        draw.text(b[:2], str(c), fill='red', font=font)

    plt.figure(figsize=[3, 3])
    plt.imshow(canvas)
    plt.show()


# Visualise the predicted boxes and class ids for the current sample image.
show(image, box_pred, class_pred)

In [9]:
# Field names indexed by predicted class id (0-7); the same for every sample,
# so the list is built once before the loop.
class_name = ['姓名','性别','民族','年','月','日','住址','号码']

# Full pipeline on five random samples: locate the fields, recognise their
# text, then print one "field->text" line per detected field.
for _ in range(5):
    image = random.choice(dataset)['image']
    class_pred, box_pred = location(image)
    text = recognition(image, box_pred)

    # Uncomment to visualise the predicted boxes:
    #show(image, box_pred, class_pred)

    for cls, label in enumerate(class_name):
        # Only report fields that were detected; if a class was predicted more
        # than once, the first occurrence is used.
        if cls in class_pred:
            idx = class_pred.index(cls)
            print(f'{label}->{text[idx]}')

姓名->邭俊苋
性别->男
民族->汉
年->1998
月->9
日->31
住址->重庆市西阳圈家族自治县
号码->
姓名->厉苉茉
性别->女
民族->汉
年->2000
月->2
日->18
住址->内蒙古自治区和塎河区
号码->546253200002187064
姓名->冉银梋
性别->女
民族->汉
年->2004
月->9
日->17
住址->香�港特别行政区山区
号码->2575722004091762011
姓名->石苉凤
性别->男
民族->汉
年->2002
月->6
日->23
住址->
号码->598936200206237046
姓名->
性别->女
民族->汉
年->
月->10
日->4
住址->贵州省辊萉市陔西县
号码->615753199610404289
