In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions this notebook was run with (one per line).
print("\n".join(module.__version__ for module in (torch, torchvision, cv2, np)))

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Evaluation data: DUTS test split (input images and their ground-truth masks).
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Alternative benchmark (DUT-OMRON); a later cell switches to these paths explicitly.
# path_image = "./DUT-OMROM/DUT-OMRON-image/"
# path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

# Limited by the available (budget) hardware: 4 is the largest batch size that fits.
batch_size = 4
# Passed to LoadTest; presumably the size images/masks are resized to — TODO confirm.
target_size = 256

In [4]:
# Evaluation loader over the configured test split; shuffle=False keeps a fixed order.
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Number of batches in one pass over the loader (shown in the progress lines).
total_batch = len(data_loader)
# Fail fast on an empty loader (e.g. wrong path_image / path_mask); otherwise the
# problem only shows up later as a ZeroDivisionError when averaging the metrics.
assert total_batch > 0, "data_loader is empty; check path_image / path_mask"
total_batch

1255

In [7]:
model = VGG16()
# map_location="cpu": the checkpoint loads even on machines without CUDA, regardless
# of which device its tensors were saved from; the model is moved to `device` later.
# strict loading (the default) turns missing/unexpected keys into an error instead of
# silently evaluating a partly-initialized network.
model.load_state_dict(torch.load("./model/model_64_edge.pth", map_location="cpu"))

<All keys matched successfully>

In [8]:
# Put the network in evaluation mode on the selected device.
# Module.to() and Module.eval() both return the module itself, so assigning the
# result leaves `model` the same object while suppressing the long module repr
# (no stray print() needed).
model = model.to(device).eval()




In [9]:
# Freeze every parameter: this notebook only runs inference, so no weight gradients
# are needed. requires_grad_ returns the module; the trailing ';' hides its repr.
model.requires_grad_(False);

In [10]:

# Evaluate the loaded model on the current data_loader (DUTS-TE here) and report
# running and final batch-averaged accuracy / precision / recall / F-measure.
# The whole loop (including the forward pass) runs under torch.no_grad() so no
# autograd state is recorded.
start_time = time.perf_counter()

n_total = len(data_loader)  # read here so the progress display is never stale
log_every = 100             # print running averages every `log_every` batches

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0

with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        # Metric helpers come from utils; presumably each returns a per-batch
        # scalar — TODO confirm against utils.
        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, n_total), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.perf_counter()

# Average over the batches actually processed; an empty loader is an error rather
# than a ZeroDivisionError.
if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9203048348426819 pre:0.8327163457870483 recall:0.8179929256439209 F-measure:0.8243775963783264
Batch:200/1255 acc:0.9236326217651367 pre:0.8390056490898132 recall:0.8060387372970581 F-measure:0.8248669505119324
Batch:300/1255 acc:0.9220995306968689 pre:0.8359277844429016 recall:0.8016412258148193 F-measure:0.8210799694061279
Batch:400/1255 acc:0.9195672273635864 pre:0.8387303948402405 recall:0.8018757700920105 F-measure:0.8232810497283936
Batch:500/1255 acc:0.9175517559051514 pre:0.8432453274726868 recall:0.7985095977783203 F-measure:0.8261039853096008
Batch:600/1255 acc:0.9166988134384155 pre:0.8418635725975037 recall:0.8012191653251648 F-measure:0.8258090615272522
Batch:700/1255 acc:0.9187339544296265 pre:0.8405216932296753 recall:0.8009854555130005 F-measure:0.8247803449630737
Batch:800/1255 acc:0.9158930778503418 pre:0.8441334962844849 recall:0.7965521216392517 F-measure:0.8262412548065186
Batch:900/1255 acc:0.9165105819702148 pre:0.8293905258178711 recall:0.78

In [11]:
# Second benchmark: DUT-OMRON images and pixel-wise ground-truth masks.
# NOTE(review): "DUT-OMROM" looks like a typo of DUT-OMRON, but it presumably matches
# the folder name on disk — confirm before renaming.
path_image, path_mask = (
    "./DUT-OMROM/DUT-OMRON-image/",
    "./DUT-OMROM/pixelwiseGT-new-PNG/",
)

omron_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(dataset=omron_dataset, batch_size=batch_size, shuffle=False)


In [12]:
# Number of batches in one pass over the DUT-OMRON loader (shown in the progress lines).
total_batch = len(data_loader)
# Fail fast on an empty loader (e.g. wrong path_image / path_mask); otherwise the
# problem only shows up later as a ZeroDivisionError when averaging the metrics.
assert total_batch > 0, "data_loader is empty; check path_image / path_mask"
total_batch

1292

In [13]:

# Evaluate the loaded model on the current data_loader (DUT-OMRON here) and report
# running and final batch-averaged accuracy / precision / recall / F-measure.
# The whole loop (including the forward pass) runs under torch.no_grad() so no
# autograd state is recorded.
start_time = time.perf_counter()

n_total = len(data_loader)  # read here so the progress display is never stale
log_every = 100             # print running averages every `log_every` batches

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0

with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        # Metric helpers come from utils; presumably each returns a per-batch
        # scalar — TODO confirm against utils.
        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, n_total), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.perf_counter()

# Average over the batches actually processed; an empty loader is an error rather
# than a ZeroDivisionError.
if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.923247218132019 pre:0.7731168866157532 recall:0.7361138463020325 F-measure:0.7498624920845032
Batch:200/1292 acc:0.9199148416519165 pre:0.7570849657058716 recall:0.7312681674957275 F-measure:0.7379509806632996
Batch:300/1292 acc:0.9153931140899658 pre:0.761807918548584 recall:0.7163340449333191 F-measure:0.73712158203125
Batch:400/1292 acc:0.9156535267829895 pre:0.7532452344894409 recall:0.7179808020591736 F-measure:0.7304732203483582
Batch:500/1292 acc:0.9175575971603394 pre:0.7521862983703613 recall:0.7230528593063354 F-measure:0.731419563293457
Batch:600/1292 acc:0.9194630980491638 pre:0.7494663000106812 recall:0.7291964292526245 F-measure:0.7311453819274902
Batch:700/1292 acc:0.9201138019561768 pre:0.7506314516067505 recall:0.7279254794120789 F-measure:0.7321797609329224
Batch:800/1292 acc:0.9199803471565247 pre:0.7527114748954773 recall:0.7296467423439026 F-measure:0.7344033122062683
Batch:900/1292 acc:0.9216774106025696 pre:0.7547982335090637 recall:0.7379425