# In [1]:
# Sample code from the TorchVision 0.3 Object Detection Finetuning Tutorial
# http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

import os
import numpy as np
import torch
from PIL import Image
import cv2
import random

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from .detection.engine import train_one_epoch, evaluate
from .detection import utils
from .detection import transforms as T
from . import cv2_util

from tensorboardX import SummaryWriter

class PennFudanDataset(object):
    """Instance-segmentation dataset in the Penn-Fudan directory layout.

    Files are read from:
      ``<root>/image/``          images (converted to RGB)
      ``<root>/mask/``           instance masks: each pixel value is an instance
                                 id, 0 is background and 255 is folded into
                                 background
      ``<work_dir>/ClassNames/`` one text file per image listing the class
                                 names of its instances, comma-separated

    The three directory listings are sorted independently and paired by
    position, so corresponding files must sort into the same order.

    Args:
        root: dataset root containing the ``image`` and ``mask`` directories.
        transforms: callable ``(img, target) -> (img, target)``, or ``None``.
        class_name: sequence of class names; an instance's label is the index
            of its name in this sequence.
        work_dir: directory containing the ``ClassNames`` directory.
    """

    def __init__(self, root, transforms, class_name, work_dir):
        self.root = root
        self.transforms = transforms
        # Sort every listing so image, mask and label files line up by index.
        self.imgs = sorted(os.listdir(os.path.join(root, "image")))
        self.masks = sorted(os.listdir(os.path.join(root, "mask")))
        self.labels = sorted(os.listdir(os.path.join(work_dir, "ClassNames")))
        self.class_name = class_name
        self.work_dir = work_dir

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``.

        ``target`` contains ``boxes`` (float32, N x 4, x1/y1/x2/y2),
        ``labels`` (int64, N), ``masks`` (uint8, N x H x W), ``image_id``,
        ``area`` and ``iscrowd`` (all zeros).

        Raises:
            ValueError: if the label file names a different number of classes
                than there are instances in the mask (labels would otherwise
                be silently misaligned with the masks).
        """
        # load image, mask and class-name file for this index
        img_path = os.path.join(self.root, "image", self.imgs[idx])
        mask_path = os.path.join(self.root, "mask", self.masks[idx])
        label_path = os.path.join(self.work_dir, "ClassNames", self.labels[idx])

        with open(label_path) as f:
            # Comma-separated names; surrounding whitespace and empty entries
            # (e.g. from a trailing comma or newline) are ignored.
            names = [s.strip() for s in f.read().split(',')]
        labels = [self.class_name.index(name) for name in names if name]

        # Context managers release the underlying file handles promptly,
        # which matters when several DataLoader workers open many files.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        # The mask is deliberately not converted to RGB: each pixel value is
        # an instance id, with 0 being background.
        with Image.open(mask_path) as m:
            mask = np.array(m)
        mask[mask == 255] = 0

        # Instance ids present in the mask, excluding background. Filtering by
        # value (not dropping the first unique id) keeps a real instance when
        # the mask happens to contain no background pixels.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        # split the id-encoded mask into one boolean H x W mask per instance
        masks = mask == obj_ids[:, None, None]
        num_objs = len(obj_ids)

        if len(labels) != num_objs:
            raise ValueError(
                'label file {} lists {} classes but mask {} contains {} instances'.format(
                    label_path, len(labels), mask_path, num_objs))

        # bounding box of each instance mask
        boxes = []
        for inst in masks:
            ys, xs = np.where(inst)
            boxes.append([np.min(xs), np.min(ys), np.max(xs), np.max(ys)])
        # reshape keeps an (N, 4) layout even when the image has no instances
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of samples, i.e. files in ``<root>/image``."""
        return len(self.imgs)


def get_model_instance_segmentation(num_classes, device, pth_path):
    """Build a Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    Weights come from the local checkpoint ``pth_path`` instead of being
    downloaded. Only tensors whose name exists in the model and whose shape
    matches are copied. Everything else keeps torchvision's initialisation.
    The box and mask predictors are then replaced with freshly initialised
    heads for ``num_classes`` classes.

    Args:
        num_classes: number of output classes, including background.
        device: passed to ``torch.load`` as ``map_location``.
        pth_path: path to either a bare state dict or a checkpoint of the form
            ``{'model': state_dict}`` (the format ``train`` saves).

    Returns:
        The model. It is not moved to ``device``; the caller does that.
    """
    # Build the architecture without downloading any pre-trained weights.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
    model_dict = model.state_dict()

    checkpoint = torch.load(pth_path, map_location=device)
    # Unwrap {'model': state_dict} checkpoints so weights saved by this
    # module can be reused as a starting point.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
        checkpoint = checkpoint['model']

    # Keep only entries the model knows about with a matching shape; this
    # skips unknown keys and class-count-specific head weights.
    compatible = {
        k: v for k, v in checkpoint.items()
        if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
    }
    model_dict.update(compatible)
    model.load_state_dict(model_dict)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the box predictor with one for num_classes
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with one for num_classes
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model


def get_transform(train):
    """Compose the per-sample transform pipeline.

    Every sample is converted to a tensor. When ``train`` is truthy, a random
    horizontal flip with probability 0.5 is appended for augmentation.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)


def random_color():
    """Return a random ``(b, g, r)`` colour tuple with channels in [0, 255].

    The three channels are drawn in b, g, r order from the module-level
    ``random`` generator, so results are reproducible under ``random.seed``.
    """
    return tuple(random.randint(0, 255) for _channel in "bgr")


def toTensor(img):
    """Convert an OpenCV BGR image into a float RGB tensor in CHW layout.

    Args:
        img: ``numpy.ndarray`` of shape (H, W, 3) in BGR order, as returned by
            ``cv2.imread`` (presumably uint8; TODO confirm for other dtypes).

    Returns:
        Float tensor of shape (3, H, W) whose values are the pixel values
        divided by 255 (uint8 input therefore maps onto [0, 1]).

    Raises:
        TypeError: if ``img`` is not a NumPy array (e.g. ``None`` from a
            failed ``cv2.imread``).
    """
    # An explicit raise, unlike ``assert``, is not stripped under ``python -O``.
    if not isinstance(img, np.ndarray):
        raise TypeError('the img type is {}, but ndarray expected'.format(type(img)))
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC -> CHW
    chw = torch.from_numpy(rgb.transpose((2, 0, 1)))
    # 255 could also be replaced by 256 here
    return chw.float().div(255)


def PredictImg(image, model, device, score_threshold=0.8, names=None, mask_threshold=0.5):
    """Run ``model`` on one image file and display the confident detections.

    Each detection scoring at least ``score_threshold`` gets a random colour,
    a filled mask outline, a bounding box and a caption. The annotated image
    is shown in an OpenCV window, blended 0.7 (boxes) / 0.3 (masks), only if
    at least one detection qualified.

    Args:
        image: path to the image file, read with ``cv2.imread``.
        model: detection model returning dicts with ``boxes``, ``labels``,
            ``scores`` and ``masks``.
        device: device the input tensor is moved to.
        score_threshold: minimum score for a detection to be drawn.
        names: mapping from ``str(label_id)`` to display name. Defaults to
            ``{'0': 'background', '1': 'person'}``. Ids that are missing from
            the mapping are shown as the id itself.
        mask_threshold: probability above which a soft mask pixel counts as
            foreground when extracting contours.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image``.
    """
    img = cv2.imread(image)
    if img is None:
        raise FileNotFoundError('cannot read image: {}'.format(image))
    result = img.copy()  # boxes and captions are drawn here
    dst = img.copy()     # filled instance masks are drawn here
    tensor = toTensor(img)

    if names is None:
        names = {'0': 'background', '1': 'person'}

    # put the model in evaluation mode
    model.eval()
    with torch.no_grad():
        prediction = model([tensor.to(device)])[0]

    boxes = prediction['boxes']
    labels = prediction['labels']
    scores = prediction['scores']
    masks = prediction['masks']

    found = False
    for idx in range(boxes.shape[0]):
        if scores[idx] < score_threshold:
            continue
        found = True
        color = random_color()

        # Binarise the soft mask before contour extraction; otherwise every
        # pixel with a non-zero probability would be treated as foreground.
        binary = (masks[idx, 0] > mask_threshold).to(torch.uint8).mul(255).cpu().numpy()
        contours, _hierarchy = cv2_util.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        cv2.drawContours(dst, contours, -1, color, -1)

        # OpenCV drawing functions require plain Python ints, not tensors.
        x1, y1, x2, y2 = (int(v) for v in boxes[idx].tolist())
        label_id = str(labels[idx].item())
        name = names.get(label_id, label_id)
        cv2.rectangle(result, (x1, y1), (x2, y2), color, thickness=2)
        cv2.putText(result, text=name, org=(x1, y1 + 10), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.5, thickness=1, lineType=cv2.LINE_AA, color=color)

    if found:
        # Blend once, after all detections have been drawn.
        blended = cv2.addWeighted(result, 0.7, dst, 0.3, 0)
        cv2.imshow('result', blended)
        cv2.waitKey()
        cv2.destroyAllWindows()


def train(device, num_classes, val_size, batch_size, lr, optimizer, total_epoch, dataset_path, class_name, model_output_path, pth_path, tensorboard_dir, work_dir):
    """Fine-tune Mask R-CNN on the dataset at ``dataset_path`` and save it.

    Args:
        device: torch device used for training.
        num_classes: number of classes, including background.
        val_size: fraction of the samples (0-1) held out for evaluation.
        batch_size: training batch size.
        lr: initial learning rate, multiplied by 0.1 every 3 epochs.
        optimizer: ``'SGD'`` selects SGD with momentum; any other value selects Adam.
        total_epoch: number of training epochs.
        dataset_path: dataset root passed to ``PennFudanDataset``.
        class_name: ordered class names used to map label text to ids.
        model_output_path: directory that receives ``model.pth``.
        pth_path: initial weights passed to ``get_model_instance_segmentation``.
        tensorboard_dir: TensorBoard log directory.
        work_dir: directory holding ``ClassNames``; also receives
            ``pre_training_weights/model.pth``.

    Returns:
        ``(metric_logger, coco_evaluator)`` from the last epoch, or
        ``(None, None)`` when ``total_epoch`` is 0.
    """
    # use our dataset and defined transformations
    dataset = PennFudanDataset(dataset_path, get_transform(train=True), class_name, work_dir)
    dataset_test = PennFudanDataset(dataset_path, get_transform(train=False), class_name, work_dir)

    # Random split: the first num_val shuffled indices form the test set.
    indices = torch.randperm(len(dataset)).tolist()
    num_val = int(len(dataset) * val_size)
    dataset = torch.utils.data.Subset(dataset, indices[num_val:])  # training images
    dataset_test = torch.utils.data.Subset(dataset_test, indices[:num_val])  # test images

    # define training and validation data loaders
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=4, collate_fn=utils.collate_fn)

    model = get_model_instance_segmentation(num_classes, device, pth_path)
    model.to(device)

    # construct the optimizer (the ``optimizer`` argument is its name)
    params = [p for p in model.parameters() if p.requires_grad]
    if optimizer == 'SGD':
        opt = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0005)
    else:
        opt = torch.optim.Adam(params, lr=lr, weight_decay=0.0005)
    # and a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.1)

    # TensorBoard log of the training run
    writer = SummaryWriter(log_dir=tensorboard_dir, flush_secs=60)
    # Initialised so the function still returns when total_epoch == 0.
    metric_logger, coco_evaluator = None, None
    try:
        for epoch in range(total_epoch):
            # print_freq=10 (presumably logs every 10 iterations)
            metric_logger = train_one_epoch(model, opt, data_loader, device, epoch, 10, writer)
            # update the learning rate
            lr_scheduler.step()
            # evaluate on the test dataset
            coco_evaluator = evaluate(model, data_loader_test, device=device)
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()

    state = {'model': model.state_dict()}
    for out_dir in (model_output_path, os.path.join(work_dir, 'pre_training_weights')):
        # Create the target directory if needed; '' means the current directory.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        utils.save_on_master(state, os.path.join(out_dir, 'model.pth'))
    print("That's it!")
    return metric_logger, coco_evaluator
