In [None]:
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

In [None]:
# Location of the pretrained U²-Net weights, relative to the notebook's working directory.
u2net_path = "../u2net.pth"

## Import the model architecture

In [None]:
# type in cmd: wget https://raw.githubusercontent.com/xuebinqin/U-2-Net/master/model/u2net.py -O model.py

# add parent root to import path
import sys, os
sys.path.append(os.path.abspath(".."))

from model import U2NET

## Load the U²-Net model

In [None]:
def load_u2net(u2net_path):
    """Build a single-channel-output U2NET and load pretrained weights into it.

    Args:
        u2net_path: path to a saved state_dict for ``U2NET(3, 1)``.

    Returns:
        The network with the weights loaded onto the CPU, switched to eval mode.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    from a trusted source.
    """
    model = U2NET(3, 1)  # 3 input channels (RGB), 1 output channel (mask)
    state = torch.load(u2net_path, map_location='cpu')
    model.load_state_dict(state)
    # Module.eval() returns the module itself, so it can be returned directly
    return model.eval()

## Image-processing helper functions

In [None]:
def predict_mask(image_path, net, input_size=320):
    """Predict a soft foreground mask for an image with a U²-Net model.

    Processing steps:
      1. Resize the image to ``input_size`` x ``input_size`` and normalize it.
      2. Run it through ``net``.
      3. Min-max scale the first output to [0, 1].
      4. Resize the result back to the image's original resolution.

    Args:
        image_path: path of the image file to segment.
        net: U2NET model (expected in eval mode); only its first output is used.
        input_size: square side length fed to the network (default 320).

    Returns:
        (mask, image): ``mask`` is a float NumPy array of shape (H, W) with
        values in [0, 1]. ``image`` is the loaded PIL image converted to RGB.
    """
    # The context manager closes the file handle once the RGB copy has been made.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    w, h = image.size  # keep original dimensions (PIL reports width, height)

    transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        # ImageNet mean/std, matching the upstream U-2-Net inference preprocessing
        # (ToTensorLab with flag=0); the pretrained weights expect normalized input.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Add a batch dimension, (1, C, input_size, input_size), on the same device as the model.
    device = next(net.parameters()).device
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # the network returns several outputs; only the first one (d1) is used here
        d1 = net(image_tensor)[0]
        pred = d1[:, 0, :, :]
        # Min-max scale to [0, 1]. The clamp stops a constant map from dividing
        # by zero, which would otherwise produce an all-NaN mask.
        lo, hi = pred.min(), pred.max()
        pred = (pred - lo) / (hi - lo).clamp_min(1e-8)
        mask = pred.squeeze().cpu().numpy()

    # scale back to original size; cv2.resize takes dsize as (width, height)
    mask = cv2.resize(mask, (w, h))

    return mask, image

In [None]:
def apply_mask(image, mask):
    """Attach ``mask`` to ``image`` as its alpha channel.

    Args:
        image: PIL image in any mode convertible to RGBA.
        mask: 2-D array of opacities, expected in [0, 1]. It is resized to the
            image size, and values outside [0, 1] are clipped.

    Returns:
        A new RGBA PIL image whose alpha channel is ``mask`` scaled to 0-255.
        The input ``image`` is not modified.
    """
    # Resize the mask to match the image; cv2.resize takes dsize as (width, height).
    # The float32 coercion lets cv2 accept masks of other dtypes (e.g. bool).
    mask_resized = cv2.resize(np.asarray(mask, dtype=np.float32), (image.width, image.height))
    # Clip before the uint8 cast, because out-of-range values would wrap around.
    # Round rather than truncate so that 1.0 maps to exactly 255.
    alpha = np.rint(np.clip(mask_resized, 0.0, 1.0) * 255).astype(np.uint8)

    # convert() returns a copy, so putalpha does not touch the caller's image
    rgba = image.convert("RGBA")
    rgba.putalpha(Image.fromarray(alpha))

    return rgba

## Test the pipeline on a sample image

In [None]:
if __name__ == '__main__':
    # Load the network once here; the next cell reuses `u2net` without reloading weights.
    u2net = load_u2net(u2net_path=u2net_path)

In [None]:
if __name__ == '__main__':

    image_path = '../images/img1.jpg'

    subject_mask, image = predict_mask(image_path, u2net)
    result = apply_mask(image, subject_mask)

    # sanity check: predict_mask returns the mask at the image's original resolution
    assert subject_mask.shape == (image.height, image.width)

    # Show the input, the predicted mask and the cut-out side by side
    # (one explicit figure instead of three copy-pasted pyplot stanzas).
    panels = [
        ('Original', image, {}),
        ('Mask', subject_mask, {'cmap': 'gray'}),
        ('Bg removed', result, {}),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (title, picture, imshow_kwargs) in zip(axes, panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')
    plt.show()
