In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions used, so the results below can be reproduced.
versions = [lib.__version__ for lib in (torch, torchvision, cv2, np)]
print(*versions, sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# DUTS test split: input images and their corresponding masks.
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"

# Size passed to LoadTest (presumably the resize resolution — confirm in data.py).
target_size = 256
# 4 is the largest batch the author's limited hardware could handle.
batch_size = 4

In [4]:
# Unshuffled loader over the test split, so batches arrive in the same order on
# every run and the progress log below is comparable between runs.
test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [5]:
# Number of batches in one pass over the loader. drop_last is not set, so the
# final batch may be partial. Used for the progress log and final averages.
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and restore the trained weights.
# map_location=device remaps the stored tensors onto the device selected earlier.
# Without it, torch.load restores tensors to the device they were saved on, so a
# GPU-saved checkpoint would fail when the notebook falls back to the CPU.
# strict=False silently skips missing/unexpected keys. The displayed result of
# load_state_dict lists any such keys, so check it before trusting the metrics.
model = VGG16()
model.load_state_dict(torch.load("./model/model_58_edge.pth", map_location=device),
                      strict=False)

<All keys matched successfully>

In [7]:
# Switch to inference mode (this matters for layers such as dropout or
# batch-norm, if the model has any) and move the weights to the selected device.
# The trailing ';' suppresses the module repr Jupyter would otherwise display.
# It replaces a bare `print()`, which only emitted a stray blank line.
model.eval()
model.to(device);




In [8]:
# Freeze every weight tensor: evaluation never needs gradients for them.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the frozen model on the test set. Running averages are printed every
# `log_interval` batches, and the overall averages once the loop finishes.
# NOTE(review): each metric is averaged over batches rather than images, so the
# final (possibly smaller) batch weighs as much as a full one. The F-measure is
# also derived from batch-level precision/recall. Confirm this matches the
# intended protocol for reporting results.
log_interval = 100

# perf_counter is monotonic, so the elapsed time is immune to wall-clock changes.
start_time = time.perf_counter()

total_loss = 0  # NOTE(review): no loss is computed in this cell; stays 0
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# The whole loop, forward pass included, runs without autograd tracking.
# Previously only the metric computation was inside torch.no_grad().
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % log_interval == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format( batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.perf_counter()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9211284518241882 pre:0.834391713142395 recall:0.817425012588501 F-measure:0.8258214592933655
Batch:200/1255 acc:0.924027681350708 pre:0.8390196561813354 recall:0.8053625822067261 F-measure:0.825498104095459
Batch:300/1255 acc:0.9228081703186035 pre:0.8387596607208252 recall:0.8030379414558411 F-measure:0.824167788028717
Batch:400/1255 acc:0.9198968410491943 pre:0.8384854197502136 recall:0.8026905059814453 F-measure:0.823792576789856
Batch:500/1255 acc:0.9179408550262451 pre:0.8439954519271851 recall:0.7988864183425903 F-measure:0.8272203207015991
Batch:600/1255 acc:0.9171110391616821 pre:0.8426375985145569 recall:0.8022951483726501 F-measure:0.8270641565322876
Batch:700/1255 acc:0.9192203283309937 pre:0.842157781124115 recall:0.801881730556488 F-measure:0.8266383409500122
Batch:800/1255 acc:0.9161946773529053 pre:0.844218909740448 recall:0.7972199320793152 F-measure:0.8267344236373901
Batch:900/1255 acc:0.916746199131012 pre:0.8301592469215393 recall:0.787182092666