In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Log the versions of the main libraries, one per line, so the numbers below
# can be tied to the environment that produced them.
print(*(lib.__version__ for lib in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Dataset locations: DUTS test split (DUTS-TE) images and their ground-truth masks.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Evaluation settings.
batch_size = 4  # limited by available hardware (presumably GPU memory); 4 is the most that fits
target_size = 256  # forwarded to LoadTest; presumably the resize resolution — TODO confirm

In [4]:
# Deterministic (unshuffled) loader over the test split, so batch order and
# therefore the running averages printed during evaluation are reproducible.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of batches in one pass over the test set (the last one may be partial).
total_batch = len(data_loader)
# Fail fast on an empty dataset (e.g. wrong DUTS paths) rather than only
# failing later when the averages are divided by the batch count.
assert total_batch > 0, "data_loader is empty; check path_image / path_mask"
total_batch

1255

In [6]:
# Build the network and load the trained weights.
model = VGG16()
# map_location=device lets a checkpoint saved from GPU tensors load on a
# CPU-only machine (the device cell falls back to 'cpu'). strict=True turns any
# missing/unexpected key into an error instead of silently evaluating a model
# with partly untrained weights.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load("./model/model_50_bce.pth", map_location=device), strict=True)

<All keys matched successfully>

In [7]:
# Put the model in inference mode (changes the behaviour of layers such as
# dropout / batch-norm, if the model has any) and move its weights to `device`.
# The trailing ';' suppresses the module repr that model.to() returns, instead
# of printing a stray blank line to hide it.
model.eval()
model.to(device);




In [8]:
# Evaluation only: freeze every weight so none of them tracks gradients.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the frozen model over the whole test set. Metrics are accumulated per
# batch and averaged over batches; running averages are printed periodically.
# NOTE(review): every batch counts equally, so a smaller final batch is slightly
# over-weighted — assumes accuracy/precision/recall return per-batch means;
# confirm in utils before switching to per-image weighting.
log_interval = 100  # print running averages every `log_interval` batches

start_time = time.time()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # number of batches actually processed (divisor for final averages)

# Inference only: run the forward pass itself (not just the metric code) without
# autograd, so no graph is recorded even if the parameters were never frozen.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure of this batch, derived from its own precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_interval == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()

# An empty loader (e.g. wrong dataset paths) would otherwise surface as an
# opaque ZeroDivisionError in the summary below.
if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9186961054801941 pre:0.7985026240348816 recall:0.8627091646194458 F-measure:0.8084118962287903
Batch:200/1255 acc:0.9220625162124634 pre:0.806028425693512 recall:0.8517464399337769 F-measure:0.8112859129905701
Batch:300/1255 acc:0.9202210903167725 pre:0.8024780750274658 recall:0.8466841578483582 F-measure:0.8067122101783752
Batch:400/1255 acc:0.9176804423332214 pre:0.8050880432128906 recall:0.8465073704719543 F-measure:0.8087626695632935
Batch:500/1255 acc:0.916016161441803 pre:0.8108566999435425 recall:0.8437850475311279 F-measure:0.8128752112388611
Batch:600/1255 acc:0.9152250289916992 pre:0.810840368270874 recall:0.845960795879364 F-measure:0.813424825668335
Batch:700/1255 acc:0.9173814058303833 pre:0.809857189655304 recall:0.8465344309806824 F-measure:0.8128975033760071
Batch:800/1255 acc:0.9147121906280518 pre:0.8141668438911438 recall:0.8416293263435364 F-measure:0.8148608803749084
Batch:900/1255 acc:0.9146795868873596 pre:0.7969505786895752 recall:0.83404982