In [None]:
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

Import the U-2-Net model architecture

In [None]:
# Prerequisite (run once in a terminal): download the U-2-Net architecture
# definition and save it as model.py somewhere on the import path, e.g. in the
# parent directory that is added to sys.path below:
#   wget https://raw.githubusercontent.com/xuebinqin/U-2-Net/master/model/u2net.py -O model.py

import os
import sys

# Make the parent directory importable. The membership check keeps this cell
# idempotent: re-running it no longer appends a duplicate entry each time.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from model import U2NET

Load the model weights

In [None]:
# Checkpoint location, relative to the notebook's working directory.
model_path = '../u2net.pth'

# Constructor arguments are presumably (input channels, output channels) as in
# the upstream u2net.py -- TODO confirm against the downloaded model.py.
net = U2NET(3, 1)

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source. Tensors are mapped onto the CPU.
net.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)

# Switch to inference mode; as the last expression this also displays the model.
net.eval()

Image-processing helper functions

In [None]:
def load_image(image_path):
    """Load an image from disk and prepare it as a network input.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        A tuple ``(image_tensor, image)`` where ``image_tensor`` is a float
        tensor of shape (1, 3, 320, 320), normalized per channel, and
        ``image`` is the RGB PIL image at its original resolution.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been decoded; convert() returns an image independent of the open file.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        # The upstream U-2-Net preprocessing normalizes RGB inputs with these
        # ImageNet statistics; feeding raw [0, 1] values to the released
        # weights puts the input off the training distribution.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # unsqueeze(0) adds the batch dimension: (c, h, w) -> (1, c, h, w)
    image_tensor = transform(image).unsqueeze(0)

    return image_tensor, image

In [None]:
def predict_mask(image_path):
    """Run the network on one image and return a mask scaled to [0, 1].

    Relies on the notebook-level ``net`` model and ``load_image`` helper.

    Args:
        image_path: Path to the image file to segment.

    Returns:
        A tuple ``(mask, image)`` where ``mask`` is a float NumPy array holding
        the min-max normalized prediction at the network's output resolution,
        and ``image`` is the loaded image converted with ``np.array``.
    """
    image_tensor, image = load_image(image_path)

    with torch.no_grad():
        # The network returns several outputs; only the first one (called d1
        # in the original code) is used for the mask.
        outputs = net(image_tensor)
        pred = outputs[0][:, 0, :, :]

        # Min-max normalize. A constant prediction has zero range, which would
        # turn the whole mask into NaN; return an all-zero mask in that case.
        lowest, highest = pred.min(), pred.max()
        value_range = highest - lowest
        if value_range > 0:
            pred = (pred - lowest) / value_range
        else:
            pred = torch.zeros_like(pred)

        mask = pred.squeeze().cpu().numpy()

    return mask, np.array(image)

In [None]:
def apply_mask(image, mask):
    """Attach a soft mask to an RGB image as its alpha channel.

    Args:
        image: Array indexed as (height, width, channel) in RGB order.
        mask: 2-D float array with values expected in [0, 1]; it may have a
            different resolution than ``image``.

    Returns:
        A four-channel RGBA array of the same height and width as ``image``,
        with the alpha channel set from the resized mask. The RGBA order
        matches what ``plt.imshow`` expects; convert to BGRA first if the
        result is to be written with ``cv2.imwrite``.
    """
    height, width = image.shape[:2]

    # cv2.resize expects the target size as (width, height).
    mask_resized = cv2.resize(mask, (width, height))

    # Guard the uint8 cast: NaN or out-of-range values would otherwise wrap
    # around or yield undefined alpha values.
    alpha = np.clip(np.nan_to_num(mask_resized), 0.0, 1.0)

    # Keep the RGB channel order (the original RGB2BGRA conversion swapped red
    # and blue, so colors came out wrong when displayed with matplotlib).
    rgba = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba[:, :, 3] = (alpha * 255).astype(np.uint8)
    return rgba

Test the pipeline on a sample image

In [None]:
# Sample image used for the demo below, relative to the notebook's directory.
image_path = "../images/img1.png"

In [None]:
# Predict the mask for the sample image, then use it as the alpha channel of
# the original image.
mask, image = predict_mask(image_path)
result = apply_mask(image=image, mask=mask)

In [None]:
# Show the input, the predicted mask and the composited result, one figure
# each, with the axes hidden.
for title, picture in (
    ('Original', image),
    ('Mask', mask),
    ('BG removed', result),
):
    plt.title(title)
    plt.imshow(picture)
    plt.axis('off')
    plt.show()