-
Notifications
You must be signed in to change notification settings - Fork 0
/
find_mask.py
87 lines (56 loc) · 1.7 KB
/
find_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# -*- coding: utf-8 -*-
"""
Created on Tue May 18 21:44:51 2021
@author: prajw
"""
import json
import glob
import cv2
import warnings
warnings.filterwarnings("ignore")
from mrcnn.config import Config
from utils import to_rle
import mrcnn.model as modellib
# Number of garment categories the detector predicts; FashionConfig adds one
# more class on top of these (presumably Mask R-CNN's background class).
NUM_CATS = 46
# Side length, in pixels, of the square images handed to the network
# (masker/masker_np resize every input to IMAGE_SIZE x IMAGE_SIZE).
IMAGE_SIZE = 512
class FashionConfig(Config):
    """Mask R-CNN training configuration for the fashion segmentation model.

    Every setting is a class attribute consumed by ``mrcnn.config.Config``.
    """

    # --- identity / classes -------------------------------------------------
    NAME = "fashion"
    # NUM_CATS garment categories plus one extra class (presumably the
    # background class Mask R-CNN expects).
    NUM_CLASSES = NUM_CATS + 1

    # --- hardware -----------------------------------------------------------
    GPU_COUNT = 1
    IMAGES_PER_GPU = 4

    # --- network ------------------------------------------------------------
    BACKBONE = 'resnet50'
    RPN_ANCHOR_SCALES = (16, 32, 64, 128, 256)
    TRAIN_ROIS_PER_IMAGE = 100

    # --- input geometry -----------------------------------------------------
    # Both bounds are pinned to IMAGE_SIZE and the model's own resizing is
    # switched off; presumably inputs are already IMAGE_SIZE squares (as
    # masker/masker_np produce) - confirm for the training pipeline.
    IMAGE_MIN_DIM = IMAGE_MAX_DIM = IMAGE_SIZE
    IMAGE_RESIZE_MODE = 'none'

    # --- schedule -----------------------------------------------------------
    STEPS_PER_EPOCH = 5500
    VALIDATION_STEPS = 100


# Training configuration instance; display() is invoked at import time
# (presumably it prints the resolved settings).
config = FashionConfig()
config.display()
class InferenceConfig(FashionConfig):
    """FashionConfig variant for inference: one GPU, one image per GPU.

    masker/masker_np pass a single-image list to ``model.detect``; presumably
    that list length must match the configured batch size - confirm against
    mrcnn.
    """

    GPU_COUNT = 1
    IMAGES_PER_GPU = 1


# Configuration used to build the inference model below.
inference_config = InferenceConfig()
# Locate the trained weights. The pattern contains no wildcard, so this only
# checks whether that exact file exists in the current working directory.
glob_list = glob.glob('mask_rcnn_fashion_0011.h5')
model_path = glob_list[0] if glob_list else ''
# Validate before building the network, and with a real exception: an
# ``assert`` is stripped under ``python -O``, after which the script would go
# on to call ``load_weights('')`` and fail with a far less helpful error.
if not model_path:
    raise FileNotFoundError("Provide path to trained weights")
model = modellib.MaskRCNN(mode='inference',
                          config=inference_config,
                          model_dir='')
print("Loading weights from ", model_path)
# by_name=True matches layers by name when restoring the saved weights.
model.load_weights(model_path, by_name=True)
# Category metadata for the dataset; only the category names are kept.
# NOTE(review): ".data/" (leading dot) is an unusual directory name - confirm
# this path is intended and not a typo for "data/" or "./data/".
# JSON is UTF-8 by specification; state the encoding explicitly so the read
# does not depend on the platform's locale default (e.g. cp1252 on Windows).
with open(".data/label_descriptions.json", encoding="utf-8") as f:
    label_descriptions = json.load(f)
label_names = [category['name'] for category in label_descriptions['categories']]
def masker(img_path):
    """Read an image from disk and run Mask R-CNN detection on it.

    Args:
        img_path: Path of the image file, read with ``cv2.imread``.

    Returns:
        Tuple ``(img, r)``: ``img`` is the image resized to
        ``IMAGE_SIZE x IMAGE_SIZE``; ``r`` is the first element of the list
        returned by ``model.detect`` (presumably the mrcnn result dict for
        this image - confirm against mrcnn).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None instead of raising;
    # without this check the problem surfaces later as an opaque cv2.resize
    # assertion error.
    if img is None:
        raise FileNotFoundError("Could not read image: {}".format(img_path))
    # NOTE(review): cv2.imread yields BGR channel order; confirm the weights
    # were trained on BGR input, otherwise convert with cv2.cvtColor first.
    # Resizing and detection are shared with the in-memory entry point.
    return masker_np(img)
def masker_np(img):
    """Run Mask R-CNN detection on an image that is already in memory.

    Args:
        img: Image array accepted by ``cv2.resize``.

    Returns:
        Tuple ``(resized, result)``: the image resized to
        ``IMAGE_SIZE x IMAGE_SIZE`` and the first element of the list
        returned by ``model.detect``.
    """
    resized = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    # detect() works on a batch: wrap the single image, unwrap its result.
    result = model.detect([resized])[0]
    return resized, result