# Mask R-CNN - Train on Shapes Dataset


本笔记本展示了如何在您自己的数据集上训练 Mask R-CNN。 为了简单起见，我们使用形状（正方形、三角形和圆形）的合成数据集，以实现快速训练。 不过，您仍然需要 GPU，因为网络主干是 Resnet101，在 CPU 上训练速度太慢。 在 GPU 上，您可以在几分钟内开始获得不错的结果，并在不到一个小时内获得良好的结果。

*Shapes* 数据集的代码如下所示。 它即时生成图像，因此不需要下载任何数据。 它可以生成任意大小的图像，因此我们选择较小的图像大小来更快地训练。

In [None]:
import os
import sys
import random
import math
import re
import time

import numpy as np
import cv2
import matplotlib
import matplotlib.pyplot as plt
import yaml
from PIL import Image

# Root directory of the project
ROOT_DIR = os.path.abspath("../../")

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log

%matplotlib inline 

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Local path to the trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
# (only when the file does not already exist at COCO_MODEL_PATH)
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

## Configurations

In [None]:
class ShapesConfig(Config):
    """Training configuration used by this notebook.

    Derives from the base Config class and overrides the values that are
    specific to the dataset trained here. The class count must agree with
    the classes registered by ``DropDataset.load_shapes`` (box, column,
    package, fruit).
    """
    # Give the configuration a recognizable name
    NAME = "shapes"

    # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
    # GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
    GPU_COUNT = 1
    # If IMAGES_PER_GPU is too large the GPU runs out of memory; lower it
    # when that happens.
    IMAGES_PER_GPU = 8

    # Number of classes (including background). The first class is the
    # background; the dataset in this notebook registers four object classes,
    # so the total is 1 + 4.
    NUM_CLASSES = 1 + 4  # background + box, column, package, fruit

    # Use small images for faster training. Set the limits of the small side
    # the large side, and that determines the image shape.
    # For real training, change these to your own image size, e.g. for
    # 1280*800 images use IMAGE_MIN_DIM = 800 and IMAGE_MAX_DIM = 1280.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small.
    # Adjust the anchor sizes to your own data, e.g.
    # RPN_ANCHOR_SCALES = (8*6, 16*6, 32*6, 64*6, 128*6)
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 100

    # use small validation steps since the epoch is small
    VALIDATION_STEPS = 5


# Instantiate the configuration and print its values
config = ShapesConfig()
config.display()

## Notebook Preferences

In [None]:
def get_ax(rows=1, cols=1, size=8):
    """Create a figure and hand back its Matplotlib Axes (or Axes array).

    This is the single place that controls how large the notebook's
    visualizations are rendered: every subplot cell is ``size`` units wide
    and ``size`` units tall.

    rows: number of subplot rows.
    cols: number of subplot columns.
    size: edge length of one subplot cell; change the default to resize
        all rendered images.
    """
    figure_width = size * cols
    figure_height = size * rows
    # plt.subplots returns (figure, axes); only the axes are needed here.
    axes = plt.subplots(rows, cols, figsize=(figure_width, figure_height))[1]
    return axes

## 在全局定义一个iter_num

In [2]:
# Module-level variable, initialised to 0.
# NOTE(review): no visible cell reads or updates it - confirm it is still needed.
iter_num=0

## Dataset

Create a synthetic dataset（创建合成数据集）

Extend the Dataset class and add a method to load the shapes dataset, `load_shapes()`, and override the following methods:
扩展 Dataset 类并添加一个方法来加载形状数据集，`load_shapes()`，并覆盖以下方法：

* load_image()
* load_mask()
* image_reference()

重写一个训练类

In [1]:
class DropDataset(utils.Dataset):
    """Dataset whose instance masks and labels come from labelme-style files.

    Every registered image carries three paths in ``image_info``: ``path``
    (the RGB image), ``mask_path`` (a label image in which pixel value ``k``
    marks instance ``k``) and ``yaml_path`` (a YAML file whose
    ``label_names`` list names each mask value in order).
    """

    # Object classes in registration order; class ids start at 1.
    _CLASS_NAMES = ("box", "column", "package", "fruit")

    def get_obj_index(self, image):
        """Return the number of instances in a label image.

        image: label image (PIL image or array). Its largest pixel value is
            taken as the instance count.
        """
        return int(np.max(image))

    def from_yaml_get_class(self, image_id):
        """Read the per-instance label names for ``image_id`` from its YAML file.

        Returns the ``label_names`` list without its first entry
        (presumably the background label - confirm against the data).
        """
        info = self.image_info[image_id]
        with open(info['yaml_path'], encoding="utf-8") as f:
            # safe_load: the file only needs plain lists/strings, and it
            # avoids constructing arbitrary objects from the YAML content.
            temp = yaml.safe_load(f)
        return list(temp['label_names'][1:])

    def draw_mask(self, num_obj, mask, image):
        """Fill ``mask`` with one binary layer per instance and return it.

        num_obj: number of instance layers to fill.
        mask: array of shape [height, width, num_obj]; modified in place.
        image: single-channel label image (PIL image or 2-D array) in which
            pixel value ``k + 1`` belongs to layer ``k``.

        Raises ValueError if the label image is not single-channel.
        """
        label_image = np.asarray(image)
        if label_image.ndim != 2:
            raise ValueError(
                "expected a single-channel label image, got shape %s"
                % (label_image.shape,))
        # Only the region present in both arrays can be labelled.
        rows = min(mask.shape[0], label_image.shape[0])
        cols = min(mask.shape[1], label_image.shape[1])
        region = label_image[:rows, :cols]
        for index in range(num_obj):
            mask[:rows, :cols, index][region == index + 1] = 1
        return mask

    def load_shapes(self, count, height, width, img_floder, mask_floder, imglist, dataset_root_path):
        """Register the classes and the first ``count`` images of ``imglist``.

        count: number of images to register.
        height, width: image size stored in ``image_info``.
        img_floder: directory containing the RGB images.
        mask_floder: directory containing the label images.
        imglist: image file names; the id is the part between the first
            ``_`` and the first ``.`` (separator inferred from the
            ``rgb_<id>_json`` directory naming used for the YAML path).
        dataset_root_path: dataset root, expected to end with a path
            separator because paths are built by concatenation.
        """
        # Add classes
        for class_id, class_name in enumerate(self._CLASS_NAMES, start=1):
            self.add_class("shapes", class_id, class_name)
        for i in range(count):
            filename = imglist[i]
            filestr = filename.split(".")[0]
            filestr = filestr.split("_")[1]
            mask_path = mask_floder + "/" + filestr + ".png"
            yaml_path = dataset_root_path + "total/rgb_" + filestr + "_json/info.yaml"
            self.add_image("shapes", image_id=i, path=img_floder + "/" + filename,
                           width=width, height=height, mask_path=mask_path, yaml_path=yaml_path)

    def load_mask(self, image_id):
        """Build the instance masks and class ids for ``image_id``.

        Returns:
            mask: uint8 array of shape [height, width, instance count].
            class_ids: int32 array with one class id per recognised label.
        """
        info = self.image_info[image_id]
        # Read the pixels inside the context manager so the file is closed.
        with Image.open(info['mask_path']) as img:
            label_image = np.array(img)
        num_obj = self.get_obj_index(label_image)
        mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
        # Each pixel holds exactly one label value, so the layers produced
        # here never overlap and no occlusion pass is required.
        mask = self.draw_mask(num_obj, mask, label_image)

        # Map each raw label to the first known class name it contains.
        # NOTE(review): labels matching no class are skipped, which leaves
        # fewer class ids than mask layers - confirm the data never does this.
        labels_form = []
        for label in self.from_yaml_get_class(image_id):
            for class_name in self._CLASS_NAMES:
                if class_name in label:
                    labels_form.append(class_name)
                    break
        class_ids = np.array([self.class_names.index(s) for s in labels_form],
                             dtype=np.int32)
        return mask, class_ids

## 代码主体修改

In [None]:
# Basic settings
# Example of a local dataset root: "/home/lijing/workspace_lj/fg_dateset/"
# The root must end with "/" because the paths below are built by concatenation.
dataset_root_path = "/dateset/"
img_floder = dataset_root_path + "rgb"
mask_floder = dataset_root_path + "mask"
# os.listdir returns names in arbitrary order; sort them so image ids are
# assigned the same way on every run.
imglist = sorted(os.listdir(img_floder))
count = len(imglist)
width = 1280
height = 800

# Prepare the train and val datasets
# Training dataset
dataset_train = DropDataset()
dataset_train.load_shapes(count, height, width, img_floder, mask_floder, imglist, dataset_root_path)
dataset_train.prepare()

# Validation dataset
# NOTE(review): this loads exactly the same images as the training dataset,
# so validation metrics will not measure generalisation - confirm intended.
dataset_val = DropDataset()
dataset_val.load_shapes(count, height, width, img_floder, mask_floder, imglist, dataset_root_path)
dataset_val.prepare()