In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions this notebook was run with (one per line:
# torch, torchvision, OpenCV, NumPy).
print("\n".join(module.__version__ for module in (torch, torchvision, cv2, np)))

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# DUTS-TE test split: input images and the corresponding ground-truth masks.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Constrained by the available hardware (presumably GPU memory): 4 is the
# largest batch size that fits.
batch_size = 4
# Passed to LoadTest; NOTE(review): presumably the side length images/masks
# are resized to — confirm in data.LoadTest.
target_size = 256

In [4]:
# Deterministic, in-order iteration over the test set (no shuffling), so the
# running metrics printed during evaluation are reproducible.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    shuffle=False,
    batch_size=batch_size,
)

In [5]:
# Number of batches in one pass over the test set; drop_last is not set, so
# the final batch may hold fewer than batch_size samples.
total_batch = len(data_loader)
# Fail fast if the dataset paths matched nothing: the evaluation cell divides
# by total_batch and would otherwise only fail after an empty loop.
assert total_batch > 0, "data_loader is empty - check path_image / path_mask"
total_batch

1255

In [6]:
model = VGG16()
# map_location="cpu" lets a checkpoint that was saved from a GPU run also load
# on a CPU-only machine (the device cell falls back to CPU). The weights are
# copied into the CPU-resident model here and moved to `device` later by
# model.to(device).
# NOTE(review): strict=False silently ignores missing/unexpected keys; the
# returned key report is displayed as this cell's output so a mismatch is
# visible — consider strict=True once the checkpoint format is settled.
model.load_state_dict(torch.load("./model/model_56_edge.pth", map_location="cpu"),
                      strict=False)

<All keys matched successfully>

In [7]:
# Switch layers such as Dropout/BatchNorm to their evaluation behaviour.
model.eval()
# nn.Module.to moves the parameters in place and returns the module itself;
# assigning the result keeps the notebook from echoing the full model repr
# (instead of printing a stray blank line to hide it).
model = model.to(device)




In [8]:
# Freeze every weight tensor: this notebook only runs inference, so no
# parameter needs gradient tracking.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the loaded model on the whole test split, printing running averages
# of accuracy, precision, recall and F-measure every 100 batches and the final
# averages plus elapsed wall-clock time at the end.
#
# The metric helpers come from `utils` and are evaluated per batch; the
# reported figures are the plain mean over batches, so every batch counts
# equally, including a possibly smaller final batch.
# NOTE(review): that equals a per-image mean only if all batches are the same
# size and the helpers return per-batch means — confirm in utils.

if total_batch == 0:
    # Clear error instead of a ZeroDivisionError at the final report.
    raise ValueError("data_loader yields no batches - check path_image / path_mask")

start_time = time.time()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# Evaluation never needs gradients: run the forward pass as well as the metric
# computation under no_grad so no autograd graph is built, regardless of
# whether the parameter-freezing cell was executed.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure of this batch, derived from its own precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % 100 == 0:
            # Running averages over the batches processed so far.
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

# CUDA kernels are launched asynchronously; wait for all queued work so the
# measured time covers the whole evaluation.
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.time()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9213927984237671 pre:0.8366588950157166 recall:0.8173002600669861 F-measure:0.827576756477356
Batch:200/1255 acc:0.9237785339355469 pre:0.8378849625587463 recall:0.8066458106040955 F-measure:0.8254958987236023
Batch:300/1255 acc:0.9220896363258362 pre:0.8353559970855713 recall:0.8037175536155701 F-measure:0.8220272064208984
Batch:400/1255 acc:0.9194955229759216 pre:0.8360452651977539 recall:0.8052409291267395 F-measure:0.8228281140327454
Batch:500/1255 acc:0.9176307916641235 pre:0.8417961597442627 recall:0.801300585269928 F-measure:0.8260688185691833
Batch:600/1255 acc:0.9167395234107971 pre:0.8406297564506531 recall:0.8034589290618896 F-measure:0.8258376121520996
Batch:700/1255 acc:0.9188424348831177 pre:0.839729368686676 recall:0.8031342625617981 F-measure:0.825168251991272
Batch:800/1255 acc:0.916013240814209 pre:0.8427293300628662 recall:0.7987781167030334 F-measure:0.826183557510376
Batch:900/1255 acc:0.916567862033844 pre:0.828984260559082 recall:0.7883888483