In [None]:
import numpy as np
import pandas as pd 
from bs4 import BeautifulSoup
import torchvision
from torchvision import transforms,datasets,models
import torch
from torch.utils.data import Dataset, DataLoader
# Import the predictor heads used with the pretrained detection models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import matplotlib.patches as patches
import os 

In [None]:
# Read the bounding box from an XML <object> node.
def generate_box(obj):
    """Return the box stored in ``obj`` as ``[xmin, ymin, xmax, ymax]``.

    ``obj`` must offer ``find(tag)`` returning a node whose ``.text`` is an
    integer string for each of the four corner tags; every value is
    converted with ``int``.
    """
    corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    return [int(obj.find(tag).text) for tag in corner_tags]

# Read the binary label from an XML <object> node: person or not.
def generate_person_label(obj):
    """Tell whether the ``name`` entry of ``obj`` is exactly ``'person'``.

    The result is a bool; True/False turn into 1/0 when the caller converts
    the labels to an integer tensor.
    """
    class_name = obj.find('name').text
    return class_name == 'person'

def generate_target(image_id,file):
    """Build the detection target dict for one XML annotation file.

    Args:
        image_id: Identifier stored (wrapped in a tensor) under 'image_id'.
        file: Path of the XML annotation file; it is read as UTF-8.

    Returns:
        dict with the keys
            'boxes'    - float32 tensor of shape (N, 4), one
                         [xmin, ymin, xmax, ymax] row per <object>
            'labels'   - int64 tensor of shape (N,), 1 where the object
                         name is 'person' and 0 otherwise
            'image_id' - tensor containing ``image_id``
    """
    # Only reading needs the file handle; parsing happens after it is closed.
    with open(file,'r',encoding='utf-8') as f:
        data = f.read()
    soup = BeautifulSoup(data,'lxml-xml')
    objs = soup.find_all('object')
    boxes = [generate_box(obj) for obj in objs]
    labels = [generate_person_label(obj) for obj in objs]
    # reshape(-1, 4) is a no-op for N >= 1 but keeps the (N, 4) layout when
    # the file has no <object>; an empty list alone would give shape (0,).
    boxes = torch.as_tensor(boxes,dtype=torch.float32).reshape(-1,4)
    labels = torch.as_tensor(labels,dtype=torch.int64)
    return {
        'boxes': boxes,
        'labels': labels,
        'image_id': torch.tensor([image_id]),
    }


# The image and annotation folders do not match one-to-one, so the image
# file names are derived from the annotation file names.
# Only '.xml' entries are kept: any other file in the folder (hidden files,
# checkpoints, thumbnails) would otherwise become a bogus image/label pair.
labels = sorted(
    name
    for name in os.listdir('D:\\ML_data_sql\\human_detect\\annotations\\')
    if name.lower().endswith('.xml')
)
imgs = [os.path.splitext(label)[0] + '.jpg' for label in labels]
# Use only part of the data: the first 20% for training and the following
# 1% for testing. imgs and labels have the same length, so one pair of cut
# points serves both lists.
train_end = int(len(labels) * 0.2)
test_end = int(len(labels) * 0.21)
use_labels = labels[:train_end]
use_imgs = imgs[:train_end]
test_imgs = imgs[train_end:test_end]
test_labels = labels[train_end:test_end]
# print(len(labels))
# print(len(imgs))
# print(imgs[0],labels[0])

In [None]:
#make dataset
class PersonDataset(Dataset):
    """Map-style dataset pairing each image with its detection target.

    Args:
        transforms: Callable applied to the PIL image, or None to return
            the PIL image unchanged.
        imgs: Image file names.
        labels: Annotation file names, index-aligned with ``imgs``.
        img_dir: Directory containing the image files. Defaults to the
            location this notebook was written against.
        label_dir: Directory containing the XML annotation files. Defaults
            to the location this notebook was written against.
    """
    def __init__(self,transforms,imgs,labels,
                 img_dir='D:\\ML_data_sql\\human_detect\\JPEGImages\\',
                 label_dir='D:\\ML_data_sql\\human_detect\\annotations\\'):
        self.transforms = transforms
        self.imgs = imgs
        self.labels = labels
        self.lens = len(imgs)
        self.img_dir = img_dir
        self.label_dir = label_dir

    def __getitem__(self,idx):
        """Return ``(image, target)`` for position ``idx``.

        ``target`` is the dict built by ``generate_target`` with ``idx`` as
        the image id.
        """
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise produce tensors whose channel count differs from the
        # RGB input the detector is built for.
        img = Image.open(os.path.join(self.img_dir,self.imgs[idx])).convert('RGB')
        target = generate_target(idx,os.path.join(self.label_dir,self.labels[idx]))
        # Apply the optional image transform (the target is left as is).
        if self.transforms is not None:
            img = self.transforms(img)
        return img,target

    def __len__(self):
        """Number of samples, i.e. the number of image file names."""
        return len(self.imgs)

# Image preprocessing pipeline; its only step is transforms.ToTensor().
data_transform = transforms.Compose([transforms.ToTensor()])

In [None]:
from torch.utils.data.dataset import random_split
# The full dataset is too large and exhausts GPU memory, so only the small
# train / test subsets are wrapped into datasets here.
dataset = PersonDataset(transforms=data_transform,imgs=use_imgs,labels=use_labels)
test_dataset = PersonDataset(transforms=data_transform,imgs=test_imgs,labels=test_labels)
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``;
    the items themselves are not stacked or modified.
    """
    columns = zip(*batch)
    return tuple(columns)
# Training batches are reshuffled every epoch so SGD does not always see the
# samples in the same fixed order; the test loader keeps a stable order.
# collate_fn keeps images and targets as tuples instead of stacking them.
traindata_loader = torch.utils.data.DataLoader(dataset, batch_size=4,shuffle=True,collate_fn=collate_fn)
testdata_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4,collate_fn=collate_fn)

In [None]:
#build the model
def get_model_instance_segmentation(num_classes):
    """Create a Faster R-CNN (ResNet-50 FPN) detector with a new box head.

    The model is built with ``pretrained=True`` and its box predictor is
    replaced by a ``FastRCNNPredictor`` sized for ``num_classes`` outputs.
    Despite the function name, no mask/segmentation head is added here.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # Input width of the existing classification layer, reused for the new head.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector
# Instantiate the detector with two output classes, matching the 0/1 labels
# produced by generate_person_label (1 = person).
model = get_model_instance_segmentation(2)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Look at the data before training: move the first batch to the device and
# print its targets.
#model.to(device)
for imgs,annotations in traindata_loader:
    imgs = [image.to(device) for image in imgs]
    annotations = [
        {key: value.to(device) for key, value in target.items()}
        for target in annotations
    ]
    print(annotations)
    # The printout should contain the boxes and labels in the right format.
    #model_outputs = model(imgs)
    break

In [None]:
#begin training
from tqdm.notebook import tqdm
import time 
epochs = 1

model.to(device)
# Hyper-parameters: SGD with momentum 0.9 and weight decay 5e-4.
# Momentum does not decay the learning rate; lr stays at 0.005 throughout.
# Only parameters that require gradients are handed to the optimizer.
params =  [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params,lr=0.005,momentum=0.9,weight_decay=0.0005)
len_dataloaer = len(traindata_loader)

# torch.save does not create missing directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs('./model',exist_ok=True)

for epoch in range(epochs):
    model.train()
    i=0
    epoch_loss = 0.0
    with tqdm(traindata_loader) as iterator:
        for imgs,annotations in iterator:
            i+=1
            imgs = [img.to(device) for img in imgs]
            annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
            #print(len(imgs),len(annotations)
            # Compute the loss. Only the first sample of every batch is fed
            # to the model (kept from the original code, presumably to
            # limit GPU memory use).
            loss_dict = model([imgs[0]],[annotations[0]])
            losses = sum(loss for loss in loss_dict.values())
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            # .item() stores a plain float; adding the tensor itself would
            # keep autograd history alive across iterations.
            epoch_loss += losses.item()
    # Save a checkpoint and report the summed loss after every epoch. The
    # previous guard tested the batch counter (i % 10), so nothing was saved
    # or printed unless the number of batches happened to be a multiple of 10.
    torch.save(model.state_dict(),'./model/model_'+str(epoch+1)+'.pth')
    print('epoch:',epoch+1,'i:',i,'loss:',epoch_loss)

In [None]:
#test
# plt_imgs = []
# plt_pred = []
# model.load_state_dict(torch.load('./model/model_begin.pt'))
# model.eval()
# model.to(device)
# for imgs, annotations in testdata_loader:
#         imgs = list(img.to(device) for img in imgs)
#         annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
#         plt_imgs.append(imgs)
        

In [None]:
# # Because of limited GPU memory, only a small part of the data is used for training, and only the first sample of each batch is trained on
# model.load_state_dict(torch.load('./model/model_begin.pt'))
# input_img,labels = test_dataset[0]
# print(test_dataset.lens)
# input_img = input_img.unsqueeze(0)
# print(input_img.shape)
# pred = model(input_img.to(device))[0]
# print(pred)
# pred_box = pred['boxes'].detach().cpu().numpy()
# print(pred_box)
# real = labels['boxes'].detach().cpu().numpy()

In [None]:
# def plot_image(img_tensor,pred):
#     fig,ax = plt.subplots(1)
#     img = img_tensor.cpu().data
#     # Display the image
#     ax.imshow(img.permute(1, 2, 0))
#     xmin,ymin,xmax,ymax = pred[0][1],pred[0][1],pred[0][2],pred[0][3]
#     #xmin, ymin, xmax, ymax = [427.7730, 534.9896, 638.9386, 720.0000]
#     # Create a Rectangle patch
#     rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=2,edgecolor='white',facecolor='none') 
#     # Add the patch to the Axes
#     ax.add_patch(rect)
#     plt.show()
# plot_image(input_img.squeeze(0),pred_box)
# plot_image(input_img.squeeze(0),real)

In [None]:
#evaluate
#see predict_img.py for the evaluation code
