In [1]:
#https://github.com/phamquiluan/PubLayNet/blob/master/maskrcnn/infer.py
import os
import sys
import random
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import transforms

import cv2
import numpy as np

In [2]:
# Seed every random number generator this notebook draws from so that
# repeated runs produce the same results.
seed = 1234
random.seed(seed)
# NumPy's global generator picks the overlay colours in overlay_mask /
# overlay_ann, so it has to be seeded too for reproducible renderings.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Lookup from integer class id to layout category name. The position of each
# name in the tuple is its class id, with 0 being the background class.
CATEGORIES2LABELS = dict(enumerate((
    "bg",
    "text",
    "title",
    "list",
    "table",
    "figure",
)))

In [4]:
def get_instance_segmentation_model(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes``.

    Starts from torchvision's ``maskrcnn_resnet50_fpn(pretrained=True)`` and
    replaces its box predictor and mask predictor with freshly created ones
    sized for ``num_classes``.

    Args:
        num_classes: Number of classes passed to both replacement predictors.

    Returns:
        The modified model, switched to evaluation mode.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # Swap the box head for one with ``num_classes`` outputs, keeping the
    # input width of the existing classifier layer.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Swap the mask head likewise, keeping its input channel count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes)

    # eval() clears the ``training`` flag on the model and on every submodule.
    # Clearing the flag by hand on a few modules is not enough: the remaining
    # submodules (model.roi_heads itself, for one) would stay in training mode.
    model.eval()
    return model

#seed = 1234
#random.seed(seed)
#torch.manual_seed(seed)
#torch.cuda.manual_seed_all(seed)
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

In [5]:
checkpoint_path = r'D:\MachineLearning\Models\phamquiluan-PubLayNet\model_196000.pth'
num_classes = 6  # one per entry in CATEGORIES2LABELS, background included
model = get_instance_segmentation_model(num_classes)
#model.cuda()

# Fail with a clear message instead of an assert, which is stripped when
# Python runs with -O.
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError('Checkpoint not found: ' + checkpoint_path)

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints obtained from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# This notebook only runs inference, so put the whole model (every submodule)
# into evaluation mode.
model.eval()

# Left as the last expression so the cell displays the key-matching result.
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [6]:
def overlay_mask(image, mask, alpha=0.5):
    """Blend a randomly coloured version of ``mask`` onto ``image``.

    Args:
        image: Image array to draw on; it is left unmodified.
            Assumed to be an H x W x 3 uint8 array -- TODO confirm with callers.
        mask: 2-D mask array. It is cast to uint8 and values above 127.5 are
            treated as foreground.
        alpha: Weight of the original image in the final blend; the coloured
            overlay gets ``1 - alpha``.

    Returns:
        A new array holding the blended image.
    """
    # Random colour with every channel in [102, 255).
    color = (np.random.random((1, 3)) * 153 + 102).tolist()[0]

    # Replicate the mask over three channels and binarise it to {0, 255}.
    mask = np.dstack([mask.astype(np.uint8)] * 3)
    mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY)[1]
    inv_mask = 255 - mask

    # Zero the foreground pixels (np.minimum returns a new array, so ``image``
    # is untouched), then fill them with the colour.
    overlay = np.minimum(image, inv_mask)
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # supported spelling.
    color_mask = (mask.astype(bool) * color).astype(np.uint8)
    overlay = np.maximum(overlay, color_mask).astype(np.uint8)

    return cv2.addWeighted(image, alpha, overlay, 1 - alpha, 0)

def overlay_ann(image, mask, box, label, score, alpha=0.5):
    """Draw one annotation (coloured mask, box outline and label) on ``image``.

    Args:
        image: Image array to draw on; it is left unmodified.
            Assumed to be an H x W x 3 uint8 array -- TODO confirm with callers.
        mask: 2-D mask array. It is cast to uint8 and values above 127.5 are
            treated as foreground.
        box: Sequence of four integers used as the rectangle corners
            ``(box[0], box[1])`` and ``(box[2], box[3])``.
        label: Value rendered as the label text via ``"{}".format(label)``.
        score: Accepted for interface compatibility; currently not drawn.
        alpha: Weight of the original image in the mask blend; the coloured
            overlay gets ``1 - alpha``.

    Returns:
        A new array with the annotation drawn on it.
    """
    # Random colour with every channel in [102, 255), shared by the mask fill
    # and the box outline.
    mask_color = (np.random.random((1, 3)) * 153 + 102).tolist()[0]

    # Replicate the mask over three channels and binarise it to {0, 255}.
    mask = np.dstack([mask.astype(np.uint8)] * 3)
    mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY)[1]
    inv_mask = 255 - mask

    # Zero the foreground pixels (np.minimum returns a new array, so ``image``
    # is untouched), then fill them with the colour.
    overlay = np.minimum(image, inv_mask)
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # supported spelling.
    color_mask = (mask.astype(bool) * mask_color).astype(np.uint8)
    overlay = np.maximum(overlay, color_mask).astype(np.uint8)

    # From here on ``image`` is the freshly blended array, so the in-place
    # drawing calls below do not touch the caller's input.
    image = cv2.addWeighted(image, alpha, overlay, 1 - alpha, 0)

    # Box outline, 1 px thick, in the mask colour.
    cv2.rectangle(
        image,
        (box[0], box[1]),
        (box[2], box[3]),
        mask_color, 1
    )

    text = "{}".format(label)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.3
    thickness = 1
    (label_width, label_height), _baseline = cv2.getTextSize(
        text, font, font_scale, thickness
    )

    # The text origin sits 10 px below the top-left corner of the box.
    text_origin = (box[0], box[1] + 10)

    # Filled background patch sized to the label text.
    cv2.rectangle(
        image,
        text_origin,
        (box[0] + label_width, box[1] + 10 - label_height),
        (223, 128, 255),
        cv2.FILLED
    )

    # Black label text on top of the patch.
    cv2.putText(
        image,
        text,
        text_origin,
        font,
        font_scale, (0, 0, 0), thickness
    )

    return image

In [None]:
image_path = r'D:\Datasets\Document Layout Analysis\PubLayNet\PMC5055614_00000.jpg'
if not os.path.isfile(image_path):
    raise FileNotFoundError('Image not found: ' + image_path)

# cv2.imread signals a decode failure by returning None rather than raising,
# so check explicitly before touching .shape.
image = cv2.imread(image_path)
if image is None:
    raise ValueError('Could not decode image: ' + image_path)

# Scale uniformly so the image height becomes 1300 px.
rat = 1300 / image.shape[0]
image = cv2.resize(image, None, fx=rat, fy=rat)

# NOTE(review): cv2.imread yields BGR channel order and the channels are
# passed to the model unswapped -- confirm this matches how the checkpoint
# was trained.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor()
])

image = transform(image)

# Inference only: no gradients are needed.
with torch.no_grad():
    prediction = model([image]) #.cuda()])

In [None]:
# Turn the tensor that was fed to the model back into an H x W x C uint8 array
# for drawing. It gets its own name so `image` stays a tensor and this cell
# can be re-run without first re-running the inference cell.
annotated = torch.squeeze(image, 0).permute(1, 2, 0).mul(255).numpy().astype(np.uint8)

print('prediction=' + str(len(prediction)))

# Detections scoring below this value are not drawn.
score_threshold = 0.7

for pred in prediction:
    print('mask=' + str(pred['masks'].shape))
    print('boxes=' + str(pred['boxes'].shape))
    print('labels=' + str(pred['labels'].shape))
    print('scores=' + str(pred['scores'].shape))

    detections = zip(pred['masks'], pred['boxes'], pred['labels'], pred['scores'])
    for mask, box, label_id, score in detections:
        score = score.item()
        if score < score_threshold:
            continue

        # First channel of the mask, rescaled by 255 and stored as uint8.
        m = mask[0].mul(255).byte().cpu().numpy()
        box = list(map(int, box.tolist()))
        label = CATEGORIES2LABELS[label_id.item()]

        annotated = overlay_ann(annotated, m, box, label, score)

# cv2.imwrite reports failure through its return value instead of raising.
if not cv2.imwrite('test.jpg', annotated):
    raise IOError("Could not write 'test.jpg'")

In [None]:
# Dump the raw outputs for the first (and only) input image: the mask tensor's
# shape, followed by the full box, label and score tensors.
print(f"mask={prediction[0]['masks'].shape}")
print(f"boxes=\n{prediction[0]['boxes']}")
print(f"labels=\n{prediction[0]['labels']}")
print(f"scores=\n{prediction[0]['scores']}")

In [None]:
# Random placeholder input used only to trace the model during export. It has
# its own name so the `image` tensor from the inference cell is not clobbered.
# A real image, resized and converted with the same transform as in the
# inference cell, can be substituted here.
dummy_image = torch.randn(3, 1300, 1300, requires_grad=True)

# The exported graph must come from the inference path, so make sure the model
# and all of its submodules are in evaluation mode before tracing.
model.eval()

output_names = ['boxes', 'labels', 'scores', 'masks']

# Axis 0 of every output is marked dynamic under the shared name 'pred', so
# the exported graph does not fix the number of predictions.
torch.onnx.export(model, [dummy_image], 'model_196000.1.onnx', opset_version=12,
                  input_names=['image'],
                  output_names=output_names,
                  dynamic_axes={name: {0: 'pred'} for name in output_names})