In [1]:
import os
import sys
import argparse
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from src.models.modnet import MODNet
# python -m demo.image_matting.colab.inference --input-path data/input --output-path data/mask --ckpt-path pretrained/modnet_photographic_portrait_matting.ckpt

In [2]:
# Locations consumed by the inference cell below. They are relative paths, so
# they resolve against the current working directory, which this cell displays
# as its output for a quick sanity check.
input_path, output_path = "data/input", "data/output"  # source images / generated results
ckpt_path = "pretrained/modnet_photographic_portrait_matting.ckpt"  # pre-trained MODNet weights

os.getcwd()

'E:\\data\\workspace\\python\\wanxiang-ai\\MODNet'

In [3]:
# Validate paths, load MODNet, and write one transparent-background PNG per
# input image into `output_path`.

# check input arguments; raise instead of exit() so the kernel stays alive
# and the problem shows up as a readable traceback
if not os.path.exists(input_path):
    raise FileNotFoundError('Cannot find input path: {0}'.format(input_path))
if not os.path.exists(ckpt_path):
    raise FileNotFoundError('Cannot find ckpt path: {0}'.format(ckpt_path))
# create the output directory on demand instead of aborting
os.makedirs(output_path, exist_ok=True)

# define hyper-parameters
ref_size = 512  # shorter-side length (px) used when an image is rescaled
stride = 32     # model input height/width are rounded down to a multiple of this

# define image to tensor transform: HWC uint8 in [0, 255] -> CHW float in [-1, 1]
im_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)


def _get_resize_hw(im_h, im_w, ref_size, stride=32):
    """Return the (height, width) at which an im_h x im_w image is fed to MODNet.

    If the longer side is below ``ref_size`` or the shorter side is above it,
    the image is rescaled (aspect ratio kept) so its shorter side equals
    ``ref_size``; otherwise the original size is kept. Both sides are then
    rounded down to a multiple of ``stride``, but never below ``stride``, so
    very thin images still produce a non-empty tensor.
    """
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w
    im_rh = max(im_rh - im_rh % stride, stride)
    im_rw = max(im_rw - im_rw % stride, stride)
    return im_rh, im_rw


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create MODNet and load the pre-trained ckpt
modnet = MODNet(backbone_pretrained=False)
# Wrapped in DataParallel before loading, as in the original code; presumably the
# checkpoint was saved from a DataParallel model ("module."-prefixed keys) — TODO confirm.
modnet = nn.DataParallel(modnet).to(device)
# NOTE: torch.load unpickles the file, which can execute code; only load trusted checkpoints.
weights = torch.load(ckpt_path, map_location=device)
modnet.load_state_dict(weights)
modnet.eval()

# inference images: regular files only, sorted so the processing order is deterministic
im_names = sorted(
    name for name in os.listdir(input_path)
    if os.path.isfile(os.path.join(input_path, name))
)
for im_name in im_names:
    print('Process image: {0}'.format(im_name))

    # read image and unify it to 3 channels; convert('RGB') also handles
    # grayscale, grayscale+alpha, palette and CMYK inputs. The `with` block
    # closes the file handle; convert() returns an independent image.
    with Image.open(os.path.join(input_path, im_name)) as img:
        im_rgb = img.convert('RGB')
    im_np = np.asarray(im_rgb)  # (H, W, 3) uint8

    # convert image to PyTorch tensor and add the mini-batch dim -> (1, 3, H, W)
    im_tensor = im_transform(im_rgb)[None, :, :, :]

    # resize image for input
    _, _, im_h, im_w = im_tensor.shape
    im_rh, im_rw = _get_resize_hw(im_h, im_w, ref_size, stride)
    im_tensor_resized = F.interpolate(im_tensor, size=(im_rh, im_rw), mode='area', recompute_scale_factor=False)

    # inference; no autograd graph is needed for prediction. The second positional
    # argument is passed as True as in the upstream demo (presumably an inference flag).
    with torch.no_grad():
        _, _, matte = modnet(im_tensor_resized.to(device), True)
        # resize the matte back to the original resolution
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area', recompute_scale_factor=False)
    matte_np = matte[0][0].cpu().numpy()  # (H, W) float

    # Alpha channel from the matte. The matte is expected in [0, 1]; clipping guards
    # the uint8 cast against wrap-around, and rounding avoids a downward bias.
    alpha = np.round(np.clip(matte_np, 0.0, 1.0) * 255).astype(np.uint8)

    # Transparent background: PNG alpha is straight (non-premultiplied), so RGB keeps
    # the original colours and only alpha carries the matte. Writing image * matte into
    # RGB would darken semi-transparent edges a second time when the PNG is composited.
    # (For a black-background composite save `im_np * matte_np[..., None]`; for a
    # grayscale mask save `alpha` as an 'L' image.)
    transparent_foreground = np.dstack([im_np, alpha])  # (H, W, 4) uint8
    transparent_foreground_img = Image.fromarray(transparent_foreground)
    transparent_foreground_name = os.path.splitext(im_name)[0] + '_transparent_foreground.png'
    transparent_foreground_img.save(os.path.join(output_path, transparent_foreground_name))


Process image: 1.jpg


  "The default behavior for interpolate/upsample with float scale_factor changed "
