# In [1]:
# Sample code from the TorchVision 0.3 Object Detection Finetuning Tutorial
# http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

import os
import numpy as np
import torch
from PIL import Image
import cv2
import random
from pathlib import Path
import io 
import pandas as pd 
import tensorflow as tf
import xml.etree.ElementTree as ET

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from .detection.engine import train_one_epoch, evaluate
from .detection import utils
from .detection import transforms as T
from . import cv2_util

from tensorboardX import SummaryWriter

class InsSegDataset(object):
    def __init__(self, transforms, tf_ds, class_name, dataset_len):
        self.transforms = transforms
        self.tf_ds = tf_ds
        self.class_name = class_name
        self.dataset_len = dataset_len
        # Create a dictionary describing the tfrecord features.
        self.image_feature_description = {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.string),
            'mask': tf.io.FixedLenFeature([], tf.string),
        }

    def __getitem__(self, idx):
        for i, e in enumerate(self.tf_ds):
            if i == idx:
                ee = _parse_image_function(e, self.image_feature_description)
                img = Image.open(io.BytesIO(ee['image'].numpy())).convert("RGB")
                mask = np.array(Image.open(io.BytesIO(ee['mask'].numpy())))
                label_root = ET.parse(io.BytesIO(ee['label'].numpy())).getroot()
                object_name = []
                for obj in label_root.iter('object'):
                    cls = obj.find('name').text
                    object_name.append(cls)
                labels = [self.class_name.index(j) for j in object_name]
                break

        # 删除mask中的白色部分
        mask[mask == 255] = 0
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return self.dataset_len


def get_model_instance_segmentation(num_classes, device, pth_path):
    """Build a Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    The network is created without downloaded weights, then initialised from
    the checkpoint at ``pth_path``: only entries whose name exists in the
    model and whose shape matches are copied. Finally the box and mask
    predictors are replaced by fresh ones with ``num_classes`` outputs.

    Args:
        num_classes: Number of classes for the new predictors.
        device: Passed to ``torch.load`` as ``map_location``.
        pth_path: Path of the checkpoint file. Either a plain state dict or
            a ``{'model': state_dict}`` wrapper is accepted.

    Returns:
        The constructed model (not yet moved to ``device``).
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
    model_dict = model.state_dict()

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(pth_path, map_location=device)
    # Checkpoints saved as {'model': state_dict} must be unwrapped first,
    # otherwise no weight name would match.
    if isinstance(checkpoint.get('model'), dict):
        checkpoint = checkpoint['model']

    # Keep only weights the model knows about and that have the same shape;
    # unknown names are skipped instead of raising KeyError.
    compatible = {
        name: weight
        for name, weight in checkpoint.items()
        if name in model_dict and np.shape(model_dict[name]) == np.shape(weight)
    }
    model_dict.update(compatible)
    model.load_state_dict(model_dict)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model


def get_transform(train):
    """Compose the image/target transform pipeline.

    The pipeline always starts with ``T.ToTensor()``; when ``train`` is
    truthy a ``T.RandomHorizontalFlip(0.5)`` step is appended.
    """
    pipeline = [T.ToTensor()]
    if train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)


def toTensor(img):
    """Turn a BGR ``numpy.ndarray`` image into a float tensor scaled by 1/255.

    The array is converted from BGR to RGB with OpenCV, its last axis is
    moved to the front, and the result is cast to float and divided by 255.
    An ``AssertionError`` is raised when ``img`` is not exactly an ndarray.
    """
    assert type(img) == np.ndarray, 'the img type is {}, but ndarry expected'.format(type(img))

    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    channels_first = rgb.transpose((2, 0, 1))
    scaled = torch.from_numpy(channels_first).float()
    # The original author notes the divisor 255 could also be 256.
    return scaled.div(255)


def _parse_image_function(example_proto, image_feature_description):
    """Parse one serialized ``tf.train.Example`` with the given feature spec.

    Returns whatever ``tf.io.parse_single_example`` produces for
    ``example_proto`` and ``image_feature_description``.
    """
    parsed_features = tf.io.parse_single_example(example_proto, image_feature_description)
    return parsed_features
    

def train(device, num_classes, val_size, batch_size, lr, optimizer, total_epoch, dataset_path, class_name, model_output_path, pth_path, tensorboard_dir, work_dir, tfrecord_dir):
    """Fine-tune Mask R-CNN on the TFRecord files found in ``tfrecord_dir``.

    Args:
        device: Device the model is moved to and trained on.
        num_classes: Number of classes for the model heads.
        val_size: Fraction of the records held out for evaluation.
        batch_size: Batch size of the training loader.
        lr: Learning rate.
        optimizer: ``'SGD'`` selects SGD; any other value selects Adam.
        total_epoch: Number of epochs to train.
        dataset_path: Unused; kept for interface compatibility.
        class_name: Sequence of class names handed to the datasets.
        model_output_path: Directory that receives ``model.pth``.
        pth_path: Checkpoint used to initialise the model.
        tensorboard_dir: Log directory for the tensorboard writer.
        work_dir: ``model.pth`` is also saved to
            ``<work_dir>/pre_training_weights``.
        tfrecord_dir: Directory scanned for non-empty ``*.tfrecord`` files.

    Returns:
        ``(metric_logger, coco_evaluator)`` of the last epoch, or
        ``(None, None)`` when ``total_epoch`` is 0.

    Raises:
        ValueError: If ``tfrecord_dir`` holds no non-empty ``*.tfrecord`` file.
    """
    # Sorted so the index -> record mapping does not depend on glob order.
    tfrecord_files = sorted(
        str(f) for f in Path(tfrecord_dir).glob("*.tfrecord") if f.is_file() and f.stat().st_size > 0
    )
    if not tfrecord_files:
        raise ValueError('no non-empty *.tfrecord file found in {}'.format(tfrecord_dir))
    tf_ds = tf.data.TFRecordDataset(filenames=tfrecord_files)
    # Count the records instead of assuming a fixed dataset size.
    num_records = sum(1 for _ in tf_ds)

    # use our dataset and defined transformations
    # NOTE(review): the training set is built with train=False, so the random
    # horizontal flip is never applied - confirm this is intended.
    dataset = InsSegDataset(get_transform(train=False), tf_ds, class_name, num_records)
    dataset_test = InsSegDataset(get_transform(train=False), tf_ds, class_name, num_records)
    # split the dataset in train and test set
    indices = torch.randperm(num_records).tolist()
    num_val = int(len(dataset) * val_size)
    dataset = torch.utils.data.Subset(dataset, indices[num_val:])  # training samples
    dataset_test = torch.utils.data.Subset(dataset_test, indices[:num_val])  # evaluation samples
    # define training and validation data loaders
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=utils.collate_fn)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=4, collate_fn=utils.collate_fn)
    # get the model using our helper function
    model = get_model_instance_segmentation(num_classes, device, pth_path)
    # move model to the right device
    model.to(device)

    # construct an optimizer (a separate local keeps the ``optimizer``
    # argument, which is a name string, from being rebound to an object)
    params = [p for p in model.parameters() if p.requires_grad]
    if optimizer == 'SGD':
        optim = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0005)
    else:
        optim = torch.optim.Adam(params, lr=lr, weight_decay=0.0005)
    # and a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=3, gamma=0.1)

    # Defined up front so the return below is valid even with zero epochs.
    metric_logger = None
    coco_evaluator = None
    # tensorboard log of the training run
    writer = SummaryWriter(log_dir=tensorboard_dir, flush_secs=60)
    try:
        for epoch in range(total_epoch):
            metric_logger = train_one_epoch(model, optim, data_loader, device, epoch, 10)
            train_loss = round(metric_logger.loss.global_avg, 3)
            print("train loss: ", train_loss, '  epoch: ', epoch)
            writer.add_scalar('TrainLoss', train_loss, epoch)
            # update the learning rate
            lr_scheduler.step()
            # evaluate on the test dataset
            coco_evaluator = evaluate(model, data_loader_test, device=device)
            # Each term is the mean over the evaluator's whole ``stats``
            # vector, reported under the name "miou".
            bbox_miou = coco_evaluator.coco_eval['bbox'].stats.mean()
            segm_miou = coco_evaluator.coco_eval['segm'].stats.mean()
            total_miou = round(bbox_miou + segm_miou, 3)
            print("eval total miou: ", total_miou, '  epoch: ', epoch)
            writer.add_scalar('EvalTotalMiou', total_miou, epoch)
    finally:
        # Flush and release the event file even if training fails.
        writer.close()

    weights = {'model': model.state_dict()}
    for target_dir in (model_output_path, os.path.join(work_dir, 'pre_training_weights')):
        # The save fails if the directory is missing, so create it first.
        os.makedirs(target_dir, exist_ok=True)
        utils.save_on_master(weights, os.path.join(target_dir, 'model.pth'))
    print("That's it!")

    return metric_logger, coco_evaluator
