In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Print the version of each core library, one per line, so the environment
# this notebook ran under is recorded next to its results.
print(*(lib.__version__ for lib in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Evaluation settings.
# target_size is passed to LoadTest; presumably the side length images/masks are
# resized to -- confirm in data.py.
target_size = 256
# Constrained by our modest hardware, 4 is the largest batch size that works.
batch_size = 4

# DUTS test split (DUTS-TE): input images and their ground-truth masks.
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

In [4]:
# Build the test dataset, then wrap it in an un-shuffled loader so batches
# always come out in the same order.
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Number of mini-batches the loader yields per pass. drop_last is not set
# (PyTorch default False), so the final batch may be smaller than batch_size.
total_batch = len(data_loader)
total_batch

1255

In [10]:
model = VGG16()
# The model is still on the CPU here, so load the checkpoint tensors onto the
# CPU too. Without map_location, a checkpoint saved from a GPU fails to load on
# a CPU-only machine, even though the device cell above falls back to CPU.
# strict=True turns any missing/unexpected key into an error instead of
# silently evaluating a partially-initialised network (the previous run already
# reported "All keys matched", so this checkpoint loads unchanged).
model.load_state_dict(torch.load("./model/model_48_bce.pth", map_location="cpu"), strict=True)

<All keys matched successfully>

In [11]:
# Put the network in evaluation mode (changes the behaviour of layers such as
# Dropout/BatchNorm, if the model has any) and move its weights to `device`.
# Both calls return the module itself, so they can be chained; the trailing ';'
# suppresses the module repr that Jupyter would otherwise display, replacing
# the old print() hack that emitted a stray blank line.
model.eval().to(device);




In [12]:
# Freeze every weight tensor: this notebook only runs inference, so no
# parameter needs gradient tracking.
for param in model.parameters():
    param.requires_grad_(False)

In [13]:

# Evaluate the frozen model on the whole test set. Running averages of
# accuracy / precision / recall / F-measure are printed every `log_every`
# batches, and the final averages plus wall-clock time are printed at the end.
# NOTE(review): each metric is averaged over batches, not images. If the utils
# metrics return per-batch means, the final (possibly smaller) batch is weighted
# like a full one -- confirm against utils.accuracy/precision/recall.
log_every = 100  # print running averages every this many batches

start_time = time.time()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0

# Pure inference: disable autograd for the forward pass as well as the metrics.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(total_acc / batch_n,
                          total_pre / batch_n,
                          total_rec / batch_n,
                          total_f_score / batch_n))

# CUDA kernels execute asynchronously; wait for outstanding work so the
# measured time covers the whole evaluation.
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.time()

if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9188006520271301 pre:0.7998426556587219 recall:0.8570958375930786 F-measure:0.8084877729415894
Batch:200/1255 acc:0.9210891723632812 pre:0.798721969127655 recall:0.8474853038787842 F-measure:0.8044276833534241
Batch:300/1255 acc:0.9196173548698425 pre:0.7964658141136169 recall:0.8480572700500488 F-measure:0.8021252751350403
Batch:400/1255 acc:0.9172965884208679 pre:0.800870418548584 recall:0.8500706553459167 F-measure:0.8059403300285339
Batch:500/1255 acc:0.9157732725143433 pre:0.8078138828277588 recall:0.8468021154403687 F-measure:0.8110573887825012
Batch:600/1255 acc:0.9147688150405884 pre:0.8065034747123718 recall:0.8483108878135681 F-measure:0.8102525472640991
Batch:700/1255 acc:0.9169183969497681 pre:0.8056395649909973 recall:0.8486962914466858 F-measure:0.809839129447937
Batch:800/1255 acc:0.9141675233840942 pre:0.8090346455574036 recall:0.8442304134368896 F-measure:0.8112450242042542
Batch:900/1255 acc:0.9142275452613831 pre:0.7920889258384705 recall:0.83586