Run inference on a folder of images with several different models

In [None]:
# Imports
from PIL import Image
import jittor as jt
from jittor import init
from jittor import nn

import sys
sys.path.insert(0, "../")
from models.birefnet.birefnet import BiRefNet
from models.udun.udun import UDUN
from models.isnet.isnet import ISNet


# Load Model
# Loading model and weights from local disk. Set the checkpoint path for each
# model in `weights_paths` below before running this cell.

import os

from utils import check_state_dict

models = ['BiRefNet', 'UDUN', 'ISNet']

# Checkpoint file for each model. The original cell passed '' to jt.load,
# which cannot load anything -- TODO fill in the real paths.
weights_paths = {
    'BiRefNet': '',
    'UDUN': '',
    'ISNet': '',
}


def _checkpoint_path(model_name):
    """Return the configured checkpoint path for *model_name*.

    Raises:
        FileNotFoundError: if no path is configured for the model or the
            configured path is not an existing file. This reports a missing
            checkpoint by model name instead of via an opaque load failure.
    """
    path = weights_paths.get(model_name, '')
    if not os.path.isfile(path):
        raise FileNotFoundError(
            "Checkpoint for {} not found at {!r}; set weights_paths[{!r}].".format(
                model_name, path, model_name))
    return path


# Instantiate, load weights into, and switch to eval mode every model listed
# in `models`; names not handled below are skipped.
for model_name in models:
    if model_name == 'BiRefNet':
        birefnet = BiRefNet(bb_pretrained=False)
        # check_state_dict is a project helper from utils; presumably it
        # normalizes the checkpoint's keys before loading -- confirm.
        # Only BiRefNet's weights are passed through it here.
        state_dict = check_state_dict(jt.load(_checkpoint_path(model_name)))
        birefnet.load_state_dict(state_dict)
        birefnet.eval()
        print('BiRefNet is ready to use.')
    elif model_name == 'UDUN':
        udun = UDUN(bb_pretrained=False)
        udun.load_state_dict(jt.load(_checkpoint_path(model_name)))
        udun.eval()
        print('UDUN is ready to use.')
    elif model_name == 'ISNet':
        isnet = ISNet()
        isnet.load_state_dict(jt.load(_checkpoint_path(model_name)))
        isnet.eval()
        print('ISNet is ready to use.')

# Input Data
# Preprocessing pipeline applied to each PIL image:
#   1. resize to 1024x1024,
#   2. convert to a channel-first float array,
#   3. normalize with the ImageNet mean/std.
# Jittor exposes its image transforms as `jittor.transform` (singular); the
# torchvision-style `jt.transforms` spelling is not a Jittor module, and
# Jittor's normalizer is named `ImageNormalize`.
from jittor import transform

transform_image = transform.Compose([
    transform.Resize((1024, 1024)),
    transform.ToTensor(),
    transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [None]:
import os
from glob import glob

src_dir = 'image_demo'
# Only regular files, in a deterministic order; glob('*') would also return
# sub-directories, which Image.open cannot read.
image_paths = sorted(
    path for path in glob(os.path.join(src_dir, '*')) if os.path.isfile(path))
dst_dir = 'pred_demo'
os.makedirs(dst_dir, exist_ok=True)
for image_path in image_paths:
    print('Processing {} ...'.format(image_path))
    # Close the source file promptly and force 3 channels: grayscale, RGBA or
    # palette images would not match the 3-channel normalization.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    # The transform may return a numpy array (no .unsqueeze), so wrap it in a
    # Jittor Var, cast to float32, and add a leading batch axis.
    input_images = jt.array(transform_image(image)).float32().unsqueeze(0)
    # Write the prediction under dst_dir with the same file name.
    # str.replace(src_dir, dst_dir) would also rewrite any other occurrence
    # of src_dir inside the path.
    save_path = os.path.join(dst_dir, os.path.basename(image_path))

    # Prediction
    for model_name in models:
        if model_name == 'BiRefNet':
            # Last output of the model, squashed to [0, 1] probabilities.
            preds = birefnet(input_images)[-1].sigmoid()
            pred = preds[0].squeeze()
            # Show Results: scale the [0, 1] map to an 8-bit grayscale mask,
            # then resize it back to the original image size.
            pred_pil = Image.fromarray((pred.numpy() * 255).astype('uint8'))
            pred_pil.resize(image.size).save(save_path)
        elif model_name == 'UDUN':
            # NOTE(review): in the visible code `scaled_preds` is never saved
            # and there is no ISNet branch -- confirm whether this cell is
            # unfinished.
            scaled_preds = udun(input_images)[2].sigmoid()
            
    