In [None]:
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

## Import the model architecture

In [None]:
# One-time setup, run from a shell (not Python) to fetch the U-2-Net
# architecture definition and save it as `model.py`:
#   wget https://raw.githubusercontent.com/xuebinqin/U-2-Net/master/model/u2net.py -O model.py

# Append the parent directory to the import path so `model.py` is also found
# there. NOTE(review): the wget command above writes model.py into the
# *current* directory, while this line adds the *parent* directory — confirm
# which location this notebook actually expects.
import sys, os
sys.path.append(os.path.abspath(".."))

# `U2NET` is the network class defined in the downloaded model.py.
from model import U2NET

## Load the pretrained model

In [None]:
def load_u2net(model_path):
    """Construct a ``U2NET(3, 1)`` network and load its pretrained weights.

    The checkpoint at ``model_path`` is deserialized onto the CPU (so no GPU
    is required). It is loaded as a state dict into a freshly built
    ``U2NET(3, 1)``. The model is switched to eval mode before it is returned.

    NOTE(review): ``U2NET(3, 1)`` presumably means 3 input channels (RGB) and
    1 output channel — confirm against model.py.

    Security: ``torch.load`` unpickles the file, so only load checkpoints from
    trusted sources.
    """
    state_dict = torch.load(model_path, map_location='cpu')
    model = U2NET(3, 1)
    model.load_state_dict(state_dict)
    model.eval()
    return model

## Image-processing helper functions

In [None]:
def load_image(image_path, size=(320, 320)):
    """Load an image from disk and preprocess it as U-2-Net input.

    Parameters
    ----------
    image_path : str or path-like
        Path to an image file readable by PIL. It is converted to RGB.
    size : int or tuple of int, optional
        Target size passed to ``transforms.Resize``. A ``(height, width)``
        tuple resizes exactly. The default ``(320, 320)`` matches the
        previously hard-coded resolution.

    Returns
    -------
    image_tensor : torch.Tensor
        Float tensor of shape ``(1, 3, height, width)``, scaled to [0, 1] and
        then standardized per channel.
    image : PIL.Image.Image
        The RGB image at its original resolution.
    """
    # `with` closes the underlying file. `convert` returns a fully loaded copy,
    # so `image` stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(size),
        # HWC uint8 in [0, 255] -> CHW float in [0, 1].
        transforms.ToTensor(),
        # The reference U-2-Net inference preprocessing (ToTensorLab) standardizes
        # each channel with these ImageNet statistics, so the pretrained weights
        # expect this input distribution. Skipping it degrades the predicted mask.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Add a batch dimension: (c, h, w) -> (1, c, h, w).
    image_tensor = transform(image).unsqueeze(0)

    return image_tensor, image

In [None]:
def predict_mask(image_path, net):
    """Predict a foreground mask for the image stored at ``image_path``.

    Parameters
    ----------
    image_path : str or path-like
        Image to segment. It is preprocessed with ``load_image``.
    net : torch.nn.Module
        The segmentation network, e.g. from ``load_u2net``. Only the first
        element of its output sequence is used.

    Returns
    -------
    mask : np.ndarray
        2-D float array: channel 0 of the first network output, min-max
        rescaled to [0, 1]. If the prediction is constant, the raw values are
        returned unchanged instead of NaNs. Its resolution is presumably the
        network input resolution (320x320 by default) — confirm against
        model.py.
    image : np.ndarray
        The original image as an ``(H, W, 3)`` uint8 RGB array.
    """
    image_tensor, image = load_image(image_path)

    # Feed the input on whatever device the model's weights live on, so a
    # model moved to GPU works too (CPU when loaded via load_u2net).
    first_param = next(net.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():
        outputs = net(image_tensor)
        # NOTE(review): in upstream U-2-Net the first output is presumably the
        # fused, sigmoid-activated map; the remaining side outputs are ignored.
        pred = outputs[0][:, 0, :, :]

        # Min-max rescale to [0, 1]. If the map is constant (max == min), the
        # division would produce NaNs, so the raw values are kept instead.
        lo, hi = pred.min(), pred.max()
        span = hi - lo
        if span > 0:
            pred = (pred - lo) / span

        mask = pred.squeeze().cpu().numpy()

    return mask, np.array(image)

In [None]:
def apply_mask(image, mask):
    """Attach ``mask`` to ``image`` as an alpha channel (background removal).

    Parameters
    ----------
    image : np.ndarray
        RGB image of shape ``(H, W, 3)``. Expected dtype is uint8, as produced
        by ``np.array`` on a PIL RGB image.
    mask : np.ndarray
        2-D float mask, expected in [0, 1], at any resolution. It is resized
        to the image's ``(H, W)`` and clipped to [0, 1].

    Returns
    -------
    np.ndarray
        ``(H, W, 4)`` uint8 array in RGBA channel order. The alpha channel is
        the mask scaled to 0..255.
    """
    height, width = image.shape[:2]
    # cv2.resize takes the target size as (width, height).
    mask_resized = cv2.resize(mask, (width, height))

    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    alpha = (np.clip(mask_resized, 0.0, 1.0) * 255).astype(np.uint8)

    # The input is RGB, so keep RGB order and only append an alpha channel.
    # COLOR_RGB2BGRA would swap red and blue, and the result is displayed with
    # matplotlib, which interprets channels as RGBA.
    rgba = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba[:, :, 3] = alpha
    return rgba

## Run the pipeline on a sample image

In [None]:
if __name__ == '__main__':

    # Paths are relative to the notebook's working directory; adjust as needed.
    image_path = '../images/img1.jpg'
    model_path = '../u2net.pth'

    net = load_u2net(model_path)
    mask, image = predict_mask(image_path, net)
    result = apply_mask(image, mask)

    # Show input, predicted mask and cut-out side by side. The mask is a 2-D
    # float array in [0, 1], so render it in grayscale on a fixed scale rather
    # than with matplotlib's default (viridis) colormap.
    panels = [
        ('Original', image, {}),
        ('Masked', mask, {'cmap': 'gray', 'vmin': 0, 'vmax': 1}),
        ('Bg removed', result, {}),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (title, picture, imshow_kwargs) in zip(axes, panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    plt.show()