In [3]:
import torch
import random
from transformers import AutoImageProcessor

#AutoImageProcessor.from_pretrained()会自动根据detr-resnet-50模型文件夹下的preprocessor_config.json所记录的格式自动裁剪图片，更具体的图片格式可以看processor/facebook/detr-resnet-50/preprocessor_config.json
#这里重写了size部分的数据，考虑到本次训练的数据集的图片的大小也就800*800，所以800*800够用了，
processor = AutoImageProcessor.from_pretrained(
    'processor/facebook/detr-resnet-50',
    size={
        'longest_edge': 800,
        'shortest_edge': 800
    })

#查看detr-resnet-50模型所需数据格式，
processor

DetrImageProcessor {
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": true,
  "do_resize": true,
  "format": "coco_detection",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "DetrImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "longest_edge": 800,
    "shortest_edge": 800
  }
}

In [4]:
from datasets import load_from_disk

#读取训练数据
dataset = load_from_disk('dataset/data')['train']

#数据格式化
def f(data):
    box = [i['box'] for i in data['ocr']]
    cls = [i['cls'] for i in data['ocr']]
    return {'image': data['image'], 'box': box, 'cls': cls}

#数据格式化
#去除ocr部分
dataset = dataset.map(f, remove_columns=['ocr'])

In [5]:
#测试上一步的结果

dataset, dataset[0]

(Dataset({
     features: ['image', 'box', 'cls'],
     num_rows: 8500
 }),
 {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=250x250>,
  'box': [[49.713941536006374,
    31.017112564616305,
    84.34739644797423,
    56.823359827267545],
   [48.396413600453215,
    60.27906174757388,
    62.12991846361899,
    85.03621709441235],
   [95.53531805083459,
    65.27309398461871,
    109.2688229140004,
    90.03024933145721],
   [47.28158227036977,
    91.88196686516804,
    71.98526505536066,
    111.97578766994951],
   [88.4345940921313,
    96.24183627846115,
    97.42530872699506,
    114.67097967089434],
   [104.89579882083592,
    97.9857840437784,
    118.37593292716456,
    116.89054955402537],
   [45.00124545883547,
    119.9734380808073,
    136.61390433736082,
    171.6230872605421],
   [71.85146106045553,
    199.14604972228662,
    169.98367684877454,
    229.34932105266336]],
  'cls': [0, 1, 2, 3, 4, 5, 6, 7]})

In [6]:
#数据校对函数，在之后的torch.utils.data.DataLoader()会使用到。每次放进dataloader的一批数据都会做下述操作
def f(data):
    #依次获取对应数据类
    image = [i['image'].convert('RGB') for i in data]
    box = [i['box'] for i in data]
    cls = [i['cls'] for i in data]

    annotations = []
    
    #整理数据为detr-resnet-50所需数据格式
    #用zip打包成一个元组，少写一个for
    for bs, cs in zip(box, cls):
        for b in bs:
            #得到定位框的长度
            b[2] -= b[0]
            #得到定位框的高度
            b[3] -= b[1]
        
        #此时b的数据格式为
        anno = [{
            #类型
            'category_id': c,
            #
            'area': 0,
            #定位框
            'bbox': b
        } for b, c in zip(bs, cs)]
        
        annotations.append({'image_id': 0, 'annotations': anno})

    #对数据进行裁剪
    data = processor(image, annotations=annotations, return_tensors='pt')

    #放进GPU，可以在此处放，也可以在dataloader()调用以后再放
    pixel_values = data['pixel_values'].to('cuda')
    pixel_mask = data['pixel_mask'].to('cuda')

    
    size = []
    class_target = []
    box_target = []
    
    #获取对应数据类
    for label in data['labels']:
        
        
        size.append(label['size'].tolist())
        
        #放进GPU
        class_target.append(label['class_labels'].to('cuda'))
        box_target.append(label['boxes'].to('cuda'))

    return pixel_values, pixel_mask, class_target, box_target, size

#collate_fn为校对函数，让每批数据在进入loader之前都调用一次。在此处的用法是每次进入dataloader之前都对数据进行一次格式调整
loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=8,
                                     shuffle=True,
                                     drop_last=True,
                                     collate_fn=f)

pixel_values, pixel_mask, class_target, box_target, size = next(iter(loader))

#测试
print(len(loader))
print(size)

pixel_values.shape, pixel_mask.shape, class_target[0], box_target[0]

The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.


1062
[[800, 800], [800, 800], [800, 800], [800, 800], [800, 800], [800, 800], [800, 800], [800, 800]]


(torch.Size([8, 3, 800, 800]),
 torch.Size([8, 800, 800]),
 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='cuda:0'),
 tensor([[0.2024, 0.2337, 0.2205, 0.1005],
         [0.1217, 0.3285, 0.0913, 0.0894],
         [0.4138, 0.3624, 0.0913, 0.0894],
         [0.1412, 0.4309, 0.1577, 0.0774],
         [0.3476, 0.4549, 0.0603, 0.0661],
         [0.4774, 0.4699, 0.0603, 0.0661],
         [0.3238, 0.5931, 0.5788, 0.2091],
         [0.4934, 0.8396, 0.6134, 0.1382]], device='cuda:0'))

In [8]:
#定义模型
class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

        #前半部分为预训练模型detr-resnet-50
        from transformers import DetrModel, DetrConfig
        #读取detr-resnet-50的config.json文件，也就是detr-resnet-50的配置信息
        config = DetrConfig.from_pretrained('processor/facebook/detr-resnet-50')
        #用配置信息初始化detr-resnet-50
        self.model = DetrModel(config)
        
        #后半部分，也就是我们自定义的部分
        
        #类标签分类器
        #本次任务只有八个分类，所以92的输出有多，可以改为8输出
        self.class_labels_classifier = torch.nn.Linear(256, 92)
        '''
        #改为8输出
        self.class_labels_classifier = torch.nn.Sequential(torch.nn.Linear(256, 92),
                                                  torch.nn.ReLU(),
                                                  torch.nn.Linear(256, 8))
        
        '''
        
        

        #定位框预测器
        self.bbox_predictor = torch.nn.Sequential(torch.nn.Linear(256, 256),
                                                  torch.nn.ReLU(),
                                                  torch.nn.Linear(256, 256),
                                                  torch.nn.ReLU(),
                                                  torch.nn.Linear(256, 4))

        #加载参数
        from transformers import AutoModelForObjectDetection
        parameters = AutoModelForObjectDetection.from_pretrained(
            'processor/facebook/detr-resnet-50', ignore_mismatched_sizes=True)

        self.model.load_state_dict(parameters.model.state_dict())

        #使自定义的网络模型类标签分类器，也就是(256, 92)的参数与detr-resnet-50的class_labels_classifier部分的第1层一致
        self.class_labels_classifier.load_state_dict(
            parameters.class_labels_classifier.state_dict())

        #使自定义的网络模型定位框预测器的第一层，也就是torch.nn.Linear(256, 256)的参数与detr-resnet-50的bbox_predictor部分的第1层一致
        self.bbox_predictor[0].load_state_dict(
            parameters.bbox_predictor.layers[0].state_dict())

        #使自定义的网络模型定位框预测器的第2层，也就是torch.nn.Linear(256, 256)的参数与detr-resnet-50的bbox_predictor的第2层一致
        self.bbox_predictor[2].load_state_dict(
            parameters.bbox_predictor.layers[1].state_dict())

        #使自定义的网络模型定位框预测器的第2层，也就是torch.nn.Linear(256, 4)的参数与detr-resnet-50的bbox_predictor的第3层一致
        self.bbox_predictor[4].load_state_dict(
            parameters.bbox_predictor.layers[2].state_dict())

        del parameters

        #模型置入GPU，以下两行可以移出模型定义，以便在外部更灵活地定义要使用的计算设备
        self.to('cuda')
        self.train()

    def forward(self, pixel_values, pixel_mask):
        #detr-resnet-50
        #总而言之，detr-resnet-50的作用是得到图片的潜在向量，相当于学习图片所含有的信息，由于模型的特性，detr-resnet-50会更偏向于学习各个对象的位置和类别
        last_hidden_state = self.model(pixel_values=pixel_values,
                                       pixel_mask=pixel_mask).last_hidden_state
        
        #类标签分类
        class_pred = self.class_labels_classifier(last_hidden_state)
        #定位框预测
        box_pred = self.bbox_predictor(last_hidden_state).sigmoid()

        return class_pred, box_pred


model = Model()

#这里只用于演示，所以冻结参数，以便快速查看输出的格式
with torch.no_grad():
    #得到
    class_pred, box_pred = model(pixel_values, pixel_mask)

class_pred.shape, box_pred.shape

ProxyError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /timm/resnet50.a1_in1k/resolve/main/model.safetensors (Caused by ProxyError('Unable to connect to proxy', SSLError(SSLZeroReturnError(6, 'TLS/SSL connection has been closed (EOF) (_ssl.c:1131)'))))"), '(Request ID: 2b5c3537-d2fa-46ef-87bd-76fedc4927af)')

In [11]:
#加载loss函数
%run 2.文字定位Loss函数.ipynb

criterion.to('cuda')

In [12]:
#效果展示

import PIL.ImageDraw
import PIL.ImageFont
from matplotlib import pyplot as plt
import numpy as np


def show(image, box_target, class_target, box_pred, class_pred, size):
    font = PIL.ImageFont.truetype('arial.ttf', size=50)

    image = image - image.min()
    image = image / image.max() * 255
    image = image.permute(1, 2, 0)
    image = np.uint8(image.to('cpu').numpy())
    image = PIL.Image.fromarray(image, 'RGB')

    box_target = matcher.xywh_to_x1y1x2y2(box_target)
    box_target[:, 0] *= size[1]
    box_target[:, 1] *= size[0]
    box_target[:, 2] *= size[1]
    box_target[:, 3] *= size[0]

    image_target = image.copy()
    draw = PIL.ImageDraw.Draw(image_target)
    for b, c in zip(box_target, class_target.tolist()):
        draw.rectangle(b.tolist(), outline='red', width=5)
        draw.text(b.tolist()[:2], str(c), fill='red', font=font)

    class_pred = class_pred.argmax(1)
    box_pred = box_pred[class_pred <= 7]
    class_pred = class_pred[class_pred <= 7]
    box_pred = matcher.xywh_to_x1y1x2y2(box_pred)
    box_pred[:, 0] *= size[1]
    box_pred[:, 1] *= size[0]
    box_pred[:, 2] *= size[1]
    box_pred[:, 3] *= size[0]

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(image_target)

    image_pred = image.copy()
    draw = PIL.ImageDraw.Draw(image_pred)
    for b, c in zip(box_pred, class_pred.tolist()):
        draw.rectangle(b.tolist(), outline='red', width=5)
        draw.text(b.tolist()[:2], str(c), fill='red', font=font)

    plt.subplot(1, 2, 2)
    plt.imshow(image_pred)
    plt.show()


show(pixel_values[0], box_target[0], class_target[0], box_pred[0],
     class_pred[0], size[0])