In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions used for this evaluation (one per line) so the
# numbers below can be reproduced.
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
# Evaluation dataset: DUT-OMRON images and their pixel-wise ground-truth masks.
# Alternative evaluation set (DUTS-TE):
#   path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
#   path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"
# NOTE(review): the top-level folder is spelled "DUT-OMROM" (not "DUT-OMRON");
# kept verbatim because it has to match the directory on disk — confirm.
path_image = "./DUT-OMROM/DUT-OMRON-image/"
path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

# Passed to LoadTest; presumably the resolution images/masks are resized to — confirm.
target_size = 256
# Largest batch the author's hardware could handle.
batch_size = 4

In [6]:
# Deterministic, in-order pass over the test set. The final partial batch is
# kept (drop_last=False, the default) so every test image is scored.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

In [7]:
# Number of batches in one pass over the evaluation set; used for progress
# reporting in the evaluation loop. Displayed as the cell output.
total_batch = len(data_loader)
total_batch

1292

In [8]:
# Build the network and load the trained weights.
# - map_location="cpu" lets a checkpoint that was saved from a GPU load on a
#   CPU-only machine too (the device cell falls back to CPU); the model is
#   moved to `device` afterwards.
# - strict=True turns any missing/unexpected key into an error instead of
#   silently leaving those layers at their random initialization, which would
#   quietly produce meaningless metrics.
model = VGG16()
model.load_state_dict(torch.load("./model/model_60_edge.pth", map_location="cpu"),
                      strict=True)

<All keys matched successfully>

In [9]:
# Switch to inference mode (changes the behaviour of dropout / batch-norm
# layers, if the model has any) and move the weights to the selected device.
# The trailing ';' suppresses the module repr that `Module.to` returns, instead
# of printing a stray blank line to hide it.
model.eval()
model.to(device);




In [11]:
# Freeze every weight: evaluation never back-propagates.
for param in model.parameters():
    param.requires_grad_(False)

In [12]:

# Run the frozen model over the whole test set, accumulating the per-batch
# accuracy / precision / recall / F-measure returned by the helpers from
# `utils`, and print running averages every `log_every` batches.
# NOTE(review): metrics are averaged per batch, so a smaller final batch counts
# as much as a full one; the averages are exact per-image means only if the
# helpers return batch means and len(dataset) is a multiple of batch_size — confirm.
log_every = 100  # progress-report interval, in batches

start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # batches actually processed (robust to an empty/short loader)

# Evaluation needs no autograd at all, so disable it for the forward pass as
# well as for the metric computation instead of relying only on frozen params.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        total_acc += accuracy(predict, mask)

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure of this batch, derived from its precision and recall.
        total_f_score += F_Measure(pre, rec)

        n_batches = batch_n

        if batch_n % log_every == 0:
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(total_acc / batch_n,
                          total_pre / batch_n,
                          total_rec / batch_n,
                          total_f_score / batch_n))

end_time = time.perf_counter()

# An empty loader would otherwise surface as an opaque ZeroDivisionError below.
if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9234947562217712 pre:0.7790001034736633 recall:0.7335079908370972 F-measure:0.7520996928215027
Batch:200/1292 acc:0.9203187227249146 pre:0.7601647973060608 recall:0.7309510707855225 F-measure:0.7395970225334167
Batch:300/1292 acc:0.9158569574356079 pre:0.7623591423034668 recall:0.7194320559501648 F-measure:0.7383936643600464
Batch:400/1292 acc:0.9163486957550049 pre:0.7547584176063538 recall:0.7217730283737183 F-measure:0.7325254082679749
Batch:500/1292 acc:0.9179418683052063 pre:0.7526240348815918 recall:0.7251181602478027 F-measure:0.7321597933769226
Batch:600/1292 acc:0.9202136397361755 pre:0.7532927393913269 recall:0.7322427034378052 F-measure:0.7343630194664001
Batch:700/1292 acc:0.9207795858383179 pre:0.7545269727706909 recall:0.7294948697090149 F-measure:0.7353994250297546
Batch:800/1292 acc:0.9204019904136658 pre:0.7555352449417114 recall:0.7307113409042358 F-measure:0.7366671562194824
Batch:900/1292 acc:0.9219359159469604 pre:0.7570210099220276 recall:0.73