In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions in the output so results can be tied to an environment.
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# DUTS-TE test split: input images and their ground-truth masks.
# (The DUT-OMRON paths are assigned in the later DUT-OMRON evaluation cell.)
duts_te_root = "./DUTS/DUTS-TE/"
path_image = duts_te_root + "DUTS-TE-Image/"
path_mask = duts_te_root + "DUTS-TE-Mask/"

# Evaluation batch size: 4 is the practical maximum on the author's
# (budget-limited) hardware.
batch_size = 4
# Size passed to LoadTest — presumably the side length images/masks are
# resized to; TODO confirm in data.LoadTest.
target_size = 256

In [4]:
# Test-set loader. shuffle=False keeps LoadTest's own sample order, so the
# batches are produced deterministically from that order.
test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Batches in one pass over the loader. drop_last is left at its default
# (False), so the final batch may hold fewer than `batch_size` samples.
total_batch = len(data_loader)
# Fail fast on an empty dataset (e.g. an image directory with no files) rather
# than hitting a ZeroDivisionError at the end of the evaluation cell.
assert total_batch > 0, "data_loader is empty; check path_image / path_mask"
total_batch

1255

In [6]:
# Build the network and load the trained weights.
model = VGG16()
# map_location remaps tensors that were saved on a GPU onto `device`, so the
# checkpoint also loads when the device cell fell back to the CPU.
state_dict = torch.load("./model/model_63_edge.pth", map_location=device)
# NOTE(review): strict=False ignores missing/unexpected keys instead of
# raising; the displayed return value lists any mismatch — check it before
# trusting the metrics below.
model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

In [7]:
# Put the network in inference mode and move its weights to `device`.
# Module.to() returns the module itself; the trailing semicolon stops IPython
# from echoing the full layer listing, without printing a stray blank line.
model.eval()
model.to(device);




In [8]:
# Freeze every parameter: evaluation never needs gradients for the weights.
# requires_grad_ returns the module, so ';' suppresses its repr in the notebook.
model.requires_grad_(False);

In [9]:

# Evaluate `model` on the current `data_loader`: print running averages every
# `log_every` batches, then the overall averages and wall-clock time.
# The metric helpers (accuracy, precision, recall, F_Measure) come from
# `utils`; each call is assumed to return one scalar per batch — TODO confirm.
# NOTE(review): every batch is weighted equally, so a smaller final batch
# (DataLoader keeps it because drop_last defaults to False) counts as much
# as a full one.
log_every = 100

# Keep this cell self-contained: set inference mode here and recount the
# batches of the loader actually being evaluated.
model.eval()
total_batch = len(data_loader)

start_time = time.time()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # batches actually processed; denominator of the final averages

# Wrap the forward pass too: evaluation needs no autograd graph, whether or
# not the parameter-freezing cell was run.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # Per-batch F-measure; the reported value is the mean of these,
        # not the F-measure of the mean precision/recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()

if n_batches == 0:
    raise ValueError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9211519956588745 pre:0.8363354206085205 recall:0.81751948595047 F-measure:0.826934278011322
Batch:200/1255 acc:0.923424482345581 pre:0.836095929145813 recall:0.8067910075187683 F-measure:0.823300838470459
Batch:300/1255 acc:0.9222207069396973 pre:0.8344525098800659 recall:0.8035025596618652 F-measure:0.8207411170005798
Batch:400/1255 acc:0.919651985168457 pre:0.8366708159446716 recall:0.8026110529899597 F-measure:0.8221774697303772
Batch:500/1255 acc:0.9177623391151428 pre:0.8427778482437134 recall:0.7984902262687683 F-measure:0.8259250521659851
Batch:600/1255 acc:0.9166973233222961 pre:0.8395371437072754 recall:0.80170738697052 F-measure:0.8244283199310303
Batch:700/1255 acc:0.9188193678855896 pre:0.8389427661895752 recall:0.8014369606971741 F-measure:0.8241152167320251
Batch:800/1255 acc:0.9159017205238342 pre:0.8419498801231384 recall:0.7975098490715027 F-measure:0.8251850605010986
Batch:900/1255 acc:0.9162983894348145 pre:0.8266974091529846 recall:0.78655642271

In [10]:
# Switch evaluation to DUT-OMRON: source images and pixel-wise ground-truth masks.
# The directory root is spelled "DUT-OMROM" here; keep it matching the folder on disk.
omron_root = "./DUT-OMROM/"
path_image = omron_root + "DUT-OMRON-image/"
path_mask = omron_root + "pixelwiseGT-new-PNG/"

test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
)


In [11]:
# Batches in one pass over the DUT-OMRON loader. drop_last is left at its
# default (False), so the final batch may hold fewer than `batch_size` samples.
total_batch = len(data_loader)
# Fail fast on an empty dataset (e.g. an image directory with no files) rather
# than hitting a ZeroDivisionError at the end of the evaluation cell.
assert total_batch > 0, "data_loader is empty; check path_image / path_mask"
total_batch

1292

In [12]:

# Evaluate `model` on the current `data_loader`: print running averages every
# `log_every` batches, then the overall averages and wall-clock time.
# The metric helpers (accuracy, precision, recall, F_Measure) come from
# `utils`; each call is assumed to return one scalar per batch — TODO confirm.
# NOTE(review): every batch is weighted equally, so a smaller final batch
# (DataLoader keeps it because drop_last defaults to False) counts as much
# as a full one.
log_every = 100

# Keep this cell self-contained: set inference mode here and recount the
# batches of the loader actually being evaluated.
model.eval()
total_batch = len(data_loader)

start_time = time.time()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # batches actually processed; denominator of the final averages

# Wrap the forward pass too: evaluation needs no autograd graph, whether or
# not the parameter-freezing cell was run.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # Per-batch F-measure; the reported value is the mean of these,
        # not the F-measure of the mean precision/recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()

if n_batches == 0:
    raise ValueError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9244320392608643 pre:0.7868181467056274 recall:0.7281914949417114 F-measure:0.7548874616622925
Batch:200/1292 acc:0.9214763641357422 pre:0.7689417600631714 recall:0.7267668843269348 F-measure:0.7445794343948364
Batch:300/1292 acc:0.9168199896812439 pre:0.7704431414604187 recall:0.7177650332450867 F-measure:0.7434800863265991
Batch:400/1292 acc:0.917252779006958 pre:0.7615127563476562 recall:0.7205209136009216 F-measure:0.7380211353302002
Batch:500/1292 acc:0.9190511107444763 pre:0.760217010974884 recall:0.726555347442627 F-measure:0.7383942008018494
Batch:600/1292 acc:0.9208983778953552 pre:0.7570948004722595 recall:0.7338256239891052 F-measure:0.7382784485816956
Batch:700/1292 acc:0.9210798144340515 pre:0.7558550834655762 recall:0.7313951849937439 F-measure:0.7373591065406799
Batch:800/1292 acc:0.9207730889320374 pre:0.7566506266593933 recall:0.7336471676826477 F-measure:0.738972008228302
Batch:900/1292 acc:0.9221975207328796 pre:0.7581236362457275 recall:0.741752