/
demo.py
58 lines (47 loc) · 1.8 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import sys
import torch
# Resolve the directory that holds this script, then its parent directory,
# and append the parent to the import search path (presumably so that the
# `segmentron` package imported below resolves when this file is run
# directly -- confirm against the repository layout).
cur_path = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.dirname(cur_path)
sys.path.append(root_path)
from torchvision import transforms
from PIL import Image
from segmentron.utils.visualize import get_color_pallete
from segmentron.models.model_zoo import get_segmentation_model
from segmentron.utils.options import parse_args
from segmentron.utils.default_setup import default_setup
from segmentron.config import cfg
def demo():
    """Run the segmentation model on one image or on a directory of images.

    Steps, in order:

    1. Parse the command line with ``parse_args()`` and load the config file
       it names into the global ``cfg`` (phase forced to ``'test'``), then
       freeze it and call ``default_setup(args)``.
    2. Create a per-run output folder under ``cfg.VISUAL.OUTPUT_DIR`` whose
       name encodes model name, backbone, dataset name and time stamp.
    3. Build the model with ``get_segmentation_model()``, move it to
       ``args.device`` and switch it to eval mode.
    4. For every input image: convert to RGB, apply ``ToTensor`` +
       ``Normalize``, run the model without gradients, take the argmax over
       dim 1 of the first model output, hand that array to
       ``get_color_pallete`` and save the result as ``<image stem>.png`` in
       the output folder.

    ``args.input_img`` may be a single image path or a directory. For a
    directory, only its regular files are processed (sub-directories are
    skipped) and they are visited in sorted order.
    """
    args = parse_args()
    cfg.update_from_file(args.config_file)
    cfg.PHASE = 'test'
    cfg.ROOT_PATH = root_path
    cfg.check_and_freeze()
    default_setup(args)

    # output folder
    output_dir = os.path.join(
        cfg.VISUAL.OUTPUT_DIR,
        'vis_result_{}_{}_{}_{}'.format(
            cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, cfg.DATASET.NAME, cfg.TIME_STAMP))
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    # image transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
    ])

    model = get_segmentation_model().to(args.device)
    model.eval()

    if os.path.isdir(args.input_img):
        # Sort for a reproducible processing order and keep regular files
        # only: a sub-directory would make Image.open() raise and abort the
        # whole run.
        candidates = (
            os.path.join(args.input_img, name)
            for name in sorted(os.listdir(args.input_img))
        )
        img_paths = [path for path in candidates if os.path.isfile(path)]
    else:
        img_paths = [args.input_img]

    for img_path in img_paths:
        # convert() returns an independent image, so the source file handle
        # can be released as soon as the conversion is done.
        with Image.open(img_path) as source:
            image = source.convert('RGB')
        images = transform(image).unsqueeze(0).to(args.device)
        with torch.no_grad():
            output = model(images)

        # NOTE(review): assumes output[0] is laid out (batch, classes, H, W)
        # so that argmax over dim 1 yields class indices -- confirm against
        # the model's forward().
        pred = torch.argmax(output[0], 1).squeeze(0).cpu().numpy()
        mask = get_color_pallete(pred, cfg.DATASET.NAME)
        outname = os.path.splitext(os.path.split(img_path)[-1])[0] + '.png'
        mask.save(os.path.join(output_dir, outname))
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    demo()