In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions next to the results so the run can be reproduced.
print("\n".join(lib.__version__ for lib in (torch, torchvision, cv2, np)))

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# DUTS test split (DUTS-TE): input images and their ground-truth masks.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Passed to LoadTest; presumably the size images/masks are resized to — confirm in data.py.
target_size = 256
# Batch size is capped by the available hardware: 4 was the most that fit.
batch_size = 4

In [4]:
# Deterministic (unshuffled) loader over the test split, so batch order is stable across runs.
test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(test_set, shuffle=False, batch_size=batch_size)

In [5]:
# Batches per full pass. DataLoader's drop_last defaults to False, so a final
# partial batch is included in this count.
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and load the trained weights.
# - map_location=device lets tensors saved on CUDA load on a CPU-only machine
#   (the device cell above explicitly allows a CPU fallback).
# - strict=True makes a mismatched checkpoint fail loudly instead of silently
#   leaving unmatched layers at their constructor values.
# The returned key report is the cell's displayed output.
model = VGG16()
model.load_state_dict(torch.load("./model/model_49_bce.pth", map_location=device),
                      strict=True)

<All keys matched successfully>

In [7]:
# Switch the network to inference behaviour (affects any dropout / batch-norm
# layers) and move its weights to the selected device. The trailing `;` keeps
# Jupyter from echoing the module repr.
model.eval()
model.to(device);




In [8]:
# Evaluation only: mark every weight tensor as not needing gradients.
for param in model.parameters():
    param.requires_grad = False

In [9]:

# Evaluate the model on the whole test set: print running averages every 100
# batches, then a final summary with wall-clock time. The metric helpers
# (accuracy / precision / recall / F_Measure) come from the project's `utils`.
# Running sums are divided by the batch count, so each batch (including a
# possibly shorter final one) is weighted equally.
start_time = time.time()

total_loss = 0  # NOTE(review): never updated; no loss is computed in this cell
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batch_n = 0  # stays 0 if the loader yields nothing (checked after the loop)

# Run the forward pass as well as the metrics under no_grad: nothing here is
# back-propagated, so no autograd graph needs to be recorded.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # Per-batch F-measure, derived from this batch's precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

end_time = time.time()

if batch_n == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
# Divide by the number of batches actually processed (equal to total_batch for
# a full pass) so the summary agrees with the running averages above.
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / batch_n,
              total_pre / batch_n,
              total_rec / batch_n,
              total_f_score / batch_n))
print("--------------------------------------------------------------")
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.919043242931366 pre:0.799809992313385 recall:0.861379861831665 F-measure:0.808846652507782
Batch:200/1255 acc:0.9213315844535828 pre:0.8002069592475891 recall:0.8522963523864746 F-measure:0.8063034415245056
Batch:300/1255 acc:0.9197977185249329 pre:0.7977540493011475 recall:0.8487136363983154 F-measure:0.8032069802284241
Batch:400/1255 acc:0.9171726107597351 pre:0.8009456396102905 recall:0.848357617855072 F-measure:0.8054882884025574
Batch:500/1255 acc:0.9154110550880432 pre:0.8072404265403748 recall:0.8457347750663757 F-measure:0.809937059879303
Batch:600/1255 acc:0.9146113395690918 pre:0.8060632944107056 recall:0.848352313041687 F-measure:0.8097383379936218
Batch:700/1255 acc:0.9167482852935791 pre:0.80545973777771 recall:0.8487510085105896 F-measure:0.8094658255577087
Batch:800/1255 acc:0.9140698909759521 pre:0.8088993430137634 recall:0.8442612886428833 F-measure:0.8110367655754089
Batch:900/1255 acc:0.9142109155654907 pre:0.7926891446113586 recall:0.83648735284