In [7]:
import os
import glob
import torch

from PIL import Image
from skimage import io
from os.path import sep, join
from torchvision import transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader

from model import BASNet
from data_loader import SalObjDataset, RescaleT, ToTensorLab

## 两个基本函数

1. 标准化处理预测值
2. 保存输出


In [8]:
def normalize(tensor):
    max_val = torch.max(tensor)
    min_val = torch.min(tensor)
    return (tensor - min_val) / (max_val - min_val)

def save_output(image_path: str, pred, output_dir=join(".", "output")):
    if os.path.exists(output_dir) is False:
        os.mkdir(output_dir)
    
    pred = pred.squeeze()
    pred = pred.cpu().data.numpy()

    img = Image.fromarray(pred * 255).convert('RGB')
    image_name = image_path.split(sep)[-1]
    original_image = io.imread(image_path)
    img = img.resize((original_image.shape[1], original_image.shape[0]), 
                     Image.BILINEAR)
    filename = '.'.join(image_name.split('.')[:-1])
    img.save(join(output_dir, filename + '_pred.png'))


## 测试

### 基本变量

In [9]:
img_dir = join('.', 'test', 'recaptcha', 'pictures')
pred_dir = join('.', 'test', 'recaptcha', 'predictions')
model_dir = join('.', 'saved_models', 'basnet_bsi', 'basnet.pth')

img_path_list = glob.glob(join(img_dir, '*.png'))

### 数据集

In [10]:
test_salobj_dataset = SalObjDataset(img_name_list = img_path_list, 
                                    lbl_name_list = [],
                                    transform = transforms.Compose([
                                        RescaleT(256),
                                        ToTensorLab(flag=0)
                                    ]))
test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, 
                                    shuffle=False, num_workers=1)

### 定义模型

In [11]:
print("...Loading BASNet...")
net = BASNet(3, 1)
net.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available():
    net.cuda()
net.eval()

...Loading BASNet...


BASNet(
  (inconv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (inbn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (inrelu): ReLU(inplace=True)
  (encoder1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   

In [12]:
for i, data in enumerate(test_salobj_dataloader):
    print(f"Inference on image {i+1}")
    print(f"Image Name: {img_path_list[i].split(sep)[-1]}")
    
    inputs = data['image']
    inputs = inputs.type(torch.FloatTensor)

    if torch.cuda.is_available():
        inputs = Variable(inputs.cuda())
    else:
        inputs = Variable(inputs)
    
    d1, *_ = net(inputs)

    pred = d1[:, 0, :, :]
    pred = normalize(pred)

    save_output(img_path_list[i], pred, pred_dir)

Inference on image 1
Image Name: 0K4L.png
Inference on image 2
Image Name: 0Q0T.png
Inference on image 3
Image Name: 23I7.png
Inference on image 4
Image Name: 27AR.png
Inference on image 5
Image Name: 28JG.png
Inference on image 6
Image Name: 29IC.png
Inference on image 7
Image Name: 2DF5.png
Inference on image 8
Image Name: 2NXQ.png
Inference on image 9
Image Name: 2S5M.png
Inference on image 10
Image Name: 2WAI.png
Inference on image 11
Image Name: 3CT7.png
Inference on image 12
Image Name: 3MKJ.png
Inference on image 13
Image Name: 3VBC.png
Inference on image 14
Image Name: 47TB.png
Inference on image 15
Image Name: 4APP.png
Inference on image 16
Image Name: 4UXQ.png
Inference on image 17
Image Name: 54AC.png
Inference on image 18
Image Name: 5GNW.png
Inference on image 19
Image Name: 5IVN.png
Inference on image 20
Image Name: 5QLJ.png
Inference on image 21
Image Name: 65KH.png
Inference on image 22
Image Name: 6IGU.png
Inference on image 23
Image Name: 6LDJ.png
Inference on image 2