In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the versions of the main libraries this notebook was run with, one per line.
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# DUTS test split: image directory and ground-truth mask directory.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Limited by hardware budget ("poverty"): 4 is the most that fits.
batch_size = 4
# Passed to LoadTest; presumably the side length images/masks are resized to — confirm.
target_size = 256

In [4]:
# Test-set loader. shuffle=False keeps batches in the dataset's own order,
# so the running metrics printed during evaluation follow a fixed sequence.
test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of batches the evaluation loop will iterate over.
total_batch = len(data_loader)
# Fail fast on an empty test set: the evaluation averages divide by the batch
# count, which would otherwise only blow up after the whole loop has run.
assert total_batch > 0, "data_loader is empty; check the dataset paths"
total_batch

1255

In [8]:
# Build the network and load the trained weights.
# - map_location="cpu": the model is still on the CPU here (it is moved to
#   `device` later), and this also lets a checkpoint saved from GPU tensors
#   load on a machine without CUDA.
# - strict loading (the default): raise on missing/unexpected keys instead of
#   silently leaving some layers at their initial values and reporting
#   metrics for a partially loaded model.
model = VGG16()
model.load_state_dict(torch.load("./model/model_52_edge.pth", map_location="cpu"))

<All keys matched successfully>

In [9]:
# Switch the network to inference mode (changes the behaviour of dropout /
# batch-norm layers, if the model has any) and move its weights to `device`.
# The trailing ";" suppresses the model repr Jupyter would otherwise display.
model.eval()
model.to(device);




In [10]:
# Freeze every parameter of the network (sets requires_grad=False on each one);
# the trailing ";" hides the returned module's repr.
model.requires_grad_(False);

In [11]:

# Evaluate on the whole test set: accumulate per-batch accuracy, precision,
# recall and F-measure, print running means every 100 batches, then print the
# overall means and the elapsed wall-clock time.
# NOTE(review): means are taken over batches, so a shorter final batch weighs
# as much as a full one — confirm this matches how the metrics are reported.
start_time = time.perf_counter()

total_loss = 0  # never updated in this cell
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0

# Inference only: run the forward pass *and* the metrics under no_grad so no
# autograd graph is built, independent of the parameters' requires_grad flags.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        # Progress report: running means over the batches seen so far.
        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.perf_counter()

# An empty loader would otherwise fail below with a bare ZeroDivisionError.
if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check the dataset paths")

# Average over the batches actually processed (equals total_batch normally).
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9204803109169006 pre:0.844905436038971 recall:0.7913247346878052 F-measure:0.8267085552215576
Batch:200/1255 acc:0.923371434211731 pre:0.8487091660499573 recall:0.782796323299408 F-measure:0.8264735341072083
Batch:300/1255 acc:0.9224019646644592 pre:0.8487155437469482 recall:0.7814452052116394 F-measure:0.8256976008415222
Batch:400/1255 acc:0.9197404384613037 pre:0.850929856300354 recall:0.7820817232131958 F-measure:0.8274123072624207
Batch:500/1255 acc:0.9174999594688416 pre:0.8559384346008301 recall:0.7766480445861816 F-measure:0.8294091820716858
Batch:600/1255 acc:0.9166660904884338 pre:0.8537827730178833 recall:0.7803995609283447 F-measure:0.8290725350379944
Batch:700/1255 acc:0.9188539385795593 pre:0.8539093136787415 recall:0.7805396914482117 F-measure:0.8293271064758301
Batch:800/1255 acc:0.9159727096557617 pre:0.8571041822433472 recall:0.7758669257164001 F-measure:0.8301609754562378
Batch:900/1255 acc:0.9167600274085999 pre:0.8443728089332581 recall:0.763422