Run inference on each image with several different models

In [4]:
# Imports
from PIL import Image
import jittor as jt
from jittor import init
from jittor import nn

import sys
from models.birefnet.birefnet import BiRefNet
from models.udun.udun import UDUN
from models.isnet.isnet import ISNet
from models.mvanet.mvanet import MVANet
from config import Config
from utils import check_state_dict

# Load Model
# Loading model and weights from local disk:


def _load_matching_weights(model, ckpt_path, model_name):
    """Load the checkpoint entries whose keys exist in ``model``'s state dict.

    Checkpoint keys unknown to the model are skipped, and model entries absent
    from the checkpoint keep their current values. Both counts are printed so
    a checkpoint/architecture mismatch is visible instead of silent.

    Args:
        model: freshly constructed jittor module.
        ckpt_path: path passed to ``jt.load``.
        model_name: label used in messages.

    Returns:
        ``model`` with the matching weights loaded and switched to eval mode.

    Raises:
        RuntimeError: if not a single checkpoint key matches the model, which
            would otherwise leave the model entirely un-loaded.
    """
    state_dict = jt.load(ckpt_path)
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
    if not pretrained_dict:
        raise RuntimeError(f'No key of {ckpt_path} matches {model_name}; check the checkpoint layout.')
    n_skipped = len(state_dict) - len(pretrained_dict)
    n_missing = len(model_dict) - len(pretrained_dict)
    if n_skipped or n_missing:
        print(f'{model_name}: skipped {n_skipped} checkpoint keys not in the model; '
              f'{n_missing} model keys not found in the checkpoint.')
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()
    return model


models = ['UDUN', 'ISNet' , 'MVANet', 'BiRefNet'][2:3] # select models to run (slice of this list)
weights = {
    'BiRefNet': 'ckpt/BiRefNet/epoch_500.pth',
    'UDUN': 'ckpt/UDUN/udun-trained-R50.pth',
    'ISNet': 'ckpt/ISNet/isnet.pth',
    'MVANet': 'ckpt/MVANet/Model_80.pth',
}
config = Config()

# Use the GPU when Jittor was built with CUDA; otherwise fall back to CPU.
jt.flags.use_cuda = 1 if jt.has_cuda else 0

for model_name in models:
    if model_name == 'BiRefNet':
        birefnet = BiRefNet(bb_pretrained=False)
        state_dict = check_state_dict(jt.load(weights['BiRefNet']))
        birefnet.load_state_dict(state_dict)
        birefnet.eval()
    elif model_name == 'UDUN':
        udun = _load_matching_weights(UDUN(bb_pretrained=False), weights['UDUN'], model_name)
    elif model_name == 'ISNet':
        isnet = _load_matching_weights(ISNet(), weights['ISNet'], model_name)
    elif model_name == 'MVANet':
        mvanet = _load_matching_weights(MVANet(bb_pretrained=False), weights['MVANet'], model_name)
    else:
        # Fail fast on typos instead of silently skipping the model.
        raise ValueError(f'Unknown model name: {model_name!r}')
    print(f'{model_name} is ready to use.')
        

# Input Data
# One preprocessing pipeline per model family. Each resizes to config.size and
# normalizes with that model's mean/std. mvanet_image (ImageNet mean/std) is
# also used for BiRefNet in the inference cell.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
ISNET_MEAN, ISNET_STD = [0.5, 0.5, 0.5], [1, 1, 1]
# UDUN's statistics are written on the 0-255 scale and rescaled to 0-1 here.
UDUN_MEAN = [c / 255.0 for c in (124.55, 118.90, 102.94)]
UDUN_STD = [c / 255.0 for c in (56.77, 55.97, 57.50)]

# NOTE(review): mvanet_image/udun_image normalize before ToTensor, while
# isnet_image does the reverse; presumably jittor's ImageNormalize accepts PIL
# images directly — confirm before reordering.
mvanet_image = jt.transform.Compose([
    jt.transform.Resize(config.size),
    jt.transform.ImageNormalize(IMAGENET_MEAN, IMAGENET_STD),
    jt.transform.ToTensor(),
])

isnet_image = jt.transform.Compose([
    jt.transform.Resize(config.size),
    jt.transform.ToTensor(),
    jt.transform.ImageNormalize(ISNET_MEAN, ISNET_STD),
])

udun_image = jt.transform.Compose([
    jt.transform.Resize(config.size),
    jt.transform.ImageNormalize(UDUN_MEAN, UDUN_STD),
    jt.transform.ToTensor(),
])

MVANet is ready to use.


In [None]:
import os
from glob import glob

src_dir = 'tutorials/image_demo'  # image path
dst_dir = 'tutorials/pred_demo'  # prediction path
# Files only (glob('*') also matches sub-directories), sorted for a
# deterministic processing order.
image_paths = sorted(p for p in glob(os.path.join(src_dir, '*')) if os.path.isfile(p))
os.makedirs(dst_dir, exist_ok=True)


def _predict(model_name, image):
    """Run ``model_name`` on a PIL RGB image and return its squeezed sigmoid map.

    Each model uses its own preprocessing pipeline and exposes its final map at
    a different position of its output (``[-1]``, ``[2]``, ``[0][0]`` or the
    output itself), as selected below.

    Raises:
        ValueError: for a model name with no branch here.
    """
    if model_name == 'BiRefNet':
        input_images = jt.array(mvanet_image(image)).unsqueeze(0)
        preds = birefnet(input_images)[-1]
    elif model_name == 'MVANet':
        input_images = jt.array(mvanet_image(image)).unsqueeze(0)
        preds = mvanet(input_images)
    elif model_name == 'UDUN':
        input_images = jt.array(udun_image(image)).unsqueeze(0)
        preds = udun(input_images)[2]
    elif model_name == 'ISNet':
        input_images = jt.array(isnet_image(image)).unsqueeze(0)
        preds = isnet(input_images)[0][0]
    else:
        raise ValueError(f'Unknown model name: {model_name!r}')
    return preds.sigmoid().squeeze()


def _save_prediction(pred, size, out_path):
    """Min-max scale a prediction map to [0, 255] and save it as a grayscale image.

    Args:
        pred: jittor Var returned by ``_predict`` (presumably 2-D (H, W) — TODO confirm).
        size: target (width, height), i.e. the original PIL ``Image.size``.
        out_path: destination file path.
    """
    ma = jt.max(pred)
    mi = jt.min(pred)
    # The epsilon prevents 0/0 (NaN) when the map is constant; such a map is
    # saved as all-black instead.
    pred = (pred - mi) / (ma - mi + 1e-8) * 255
    pred_pil = jt.transform.ToPILImage()(pred)
    pred_pil.resize(size).convert('L').save(out_path)


for image_path in image_paths:
    print('Processing {} ...'.format(image_path))
    # One output folder per input image (named after the image file), holding
    # one PNG per model.
    out_dir = os.path.join(dst_dir, os.path.basename(image_path))
    os.makedirs(out_dir, exist_ok=True)
    # convert('RGB') so RGBA / grayscale / palette inputs match the 3-channel
    # normalization; the context manager closes the underlying file.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    # Prediction
    with jt.no_grad():
        for model_name in models:
            pred = _predict(model_name, image)
            _save_prediction(pred, image.size, os.path.join(out_dir, f'{model_name}.png'))
print('Finished')

Processing tutorials/image_demo/1#Accessories#1#Bag#3811492306_4ae60c73b6_o.jpg ...



Compiling Operators(5/68) used: 2.31s eta: 29.1s 15/68) used: 3.32s eta: 11.7s 16/68) used: 4.32s eta:   14s 24/68) used: 5.33s eta: 9.77s 31/68) used: 6.33s eta: 7.56s 36/68) used: 7.34s eta: 6.52s 40/68) used: 8.34s eta: 5.84s 46/68) used: 9.34s eta: 4.47s 50/68) used: 10.3s eta: 3.73s 59/68) used: 11.4s eta: 1.73s 61/68) used: 12.4s eta: 1.42s 66/68) used: 13.4s eta: 0.405s 68/68) used: 14.4s eta:    0s 


Processing tutorials/image_demo/2#Aircraft#1#Airplane#947427810_e51e389ce9_o.jpg ...
Finished
