In [1]:
import os
import glob
import torch
import numpy as np

from PIL import Image
from skimage import io
from os.path import sep, join
from torchvision import transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader

from model.basnet import BASNet
from data_loader import SalObjDataset, RescaleT, ToTensorLab

## Task-4 基于BASNet的抠图

其实 Task-2 已经完成了这个项目的绝大部分内容，只需要稍微调整一下保存函数：把预测出的显著性图作为 mask（即 alpha 通道）叠加到原图上，保存为 RGBA 格式的 PNG 就可以了。
所以只需要添加以下代码：
```python
def save_matting_output(image_path: str, pred, output_dir=join(".", "output")):
    ...

    mask = np.array(img.convert('L'))
    masked = np.dstack((original_image, mask))
    masked_img = Image.fromarray(masked, 'RGBA')

    ...
```
这样就好啦~

In [2]:
def normalize(tensor):
    """Min-max scale a tensor to the range [0, 1].

    Used to turn BASNet's raw saliency output into an alpha mask.

    A constant tensor (max == min) has no range to stretch. Dividing by
    that zero range would fill the result with NaNs, which then become
    garbage in the saved mask. Such a tensor is mapped to all zeros instead.

    Args:
        tensor: floating-point tensor of any (non-empty) shape.

    Returns:
        A tensor of the same shape with values in [0, 1].
    """
    max_val = torch.max(tensor)
    min_val = torch.min(tensor)
    value_range = max_val - min_val
    if value_range == 0:
        return torch.zeros_like(tensor)
    return (tensor - min_val) / value_range

def save_matting_output(image_path: str, pred, output_dir=join(".", "output")):
    """Cut out the salient object of an image and save it as an RGBA PNG.

    The saliency prediction becomes the alpha channel of the original image:
    0 is fully transparent (background) and 255 is fully opaque (foreground).
    The result is written to ``<output_dir>/<image stem>_matting.png``.

    Args:
        image_path: path of the source image the prediction was computed on.
        pred: saliency-map tensor for this image. Values are expected in
            [0, 1], e.g. the output of ``normalize``; values outside that
            range are clipped. Singleton dimensions are squeezed away.
            NOTE(review): assumes a batch of exactly one image.
        output_dir: directory to write into. It is created, together with
            any missing parents, if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Saliency map -> 8-bit grayscale mask. NaNs (e.g. from min-max scaling a
    # constant map) are treated as background instead of corrupting the cast.
    pred_np = pred.squeeze().detach().cpu().numpy()
    pred_np = np.nan_to_num(pred_np, nan=0.0)
    mask_arr = np.clip(np.rint(pred_np * 255), 0, 255).astype(np.uint8)
    mask = Image.fromarray(mask_arr)  # 2-D uint8 -> mode 'L'

    # Force 3-channel RGB so grayscale, palette or RGBA inputs still yield a
    # valid 4-channel RGBA result once the mask is attached.
    with Image.open(image_path) as src:
        matted = src.convert('RGB')

    # The network ran on a rescaled copy; stretch the mask back to the
    # original resolution before using it as alpha.
    mask = mask.resize(matted.size, Image.BILINEAR)
    matted.putalpha(mask)

    # splitext/basename handle extension-less names and either path separator.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    matted.save(join(output_dir, stem + '_matting.png'))


In [3]:
img_dir = join('.', 'test', 'img-matting')
pred_dir = join('.', 'test', 'img-matting', 'outputs')
# Path to the checkpoint file itself (not a directory), despite the name.
model_dir = join('.', 'saved_models', 'basnet_bsi', 'basnet.pth')

# glob() returns files in arbitrary, filesystem-dependent order. Sort them so
# the processing order and the inference log are reproducible across runs.
img_path_list = sorted(glob.glob(join(img_dir, '*.jpg')))
# Fail early with a clear message instead of silently producing no outputs.
assert img_path_list, f"no .jpg images found in {img_dir}"

test_salobj_dataset = SalObjDataset(img_name_list = img_path_list,
                                    lbl_name_list = [],
                                    transform = transforms.Compose([
                                        RescaleT(256),
                                        ToTensorLab(flag=0)
                                    ]))
# The inference loop pairs batch i with img_path_list[i]. That only works if
# the loader keeps dataset order (shuffle=False) and yields one image per batch.
test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1,
                                    shuffle=False, num_workers=1)

In [4]:
print("...Loading BASNet...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = BASNet(3, 1)
# map_location lets a checkpoint saved from a GPU process load on a
# CPU-only machine (and vice versa).
net.load_state_dict(torch.load(model_dir, map_location=device))
net.to(device)
net.eval()

if __name__ == "__main__":
    # Inference only: no_grad avoids building an autograd graph per image,
    # which saves memory and time.
    with torch.no_grad():
        for i, data in enumerate(test_salobj_dataloader):
            # Relies on the loader using shuffle=False and batch_size=1, so
            # batch i corresponds to img_path_list[i].
            image_path = img_path_list[i]
            print(f"Inference on image {i+1}")
            print(f"Image Name: {os.path.basename(image_path)}")

            inputs = data['image'].float().to(device)

            # BASNet returns several outputs; only the first (d1) is used here.
            d1, *_ = net(inputs)

            pred = d1[:, 0, :, :]
            pred = normalize(pred)

            save_matting_output(image_path, pred, pred_dir)

...Loading BASNet...




Inference on image 1
Image Name: alpaca.jpg
Inference on image 2
Image Name: cat.jpg
Inference on image 3
Image Name: character-C.jpg
