In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions used for this run, one per line, so results can be reproduced.
print("\n".join(module.__version__ for module in (torch, torchvision, cv2, np)))

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Evaluation data: DUTS test split (DUTS-TE). Images and ground-truth masks live in
# parallel directories. DUT-OMRON is evaluated further down with its own paths.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Limited by the available hardware: 4 is the largest batch size that fits.
batch_size = 4
# Size argument passed to LoadTest (presumably the side length images/masks are
# resized to -- confirm in data.LoadTest).
target_size = 256

In [4]:
# Deterministic, unshuffled iteration over the test set so progress lines are comparable across runs.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of batches the loader yields: ceil(len(dataset) / batch_size), since the
# DataLoader above is built without drop_last. Used as the denominator in the
# evaluation progress output.
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and load the trained weights.
# - map_location="cpu" lets the checkpoint load even on a CPU-only machine (the device
#   selection earlier explicitly allows that case). The model is moved to `device`
#   in a later cell.
# - Strict loading (the default) makes any missing/unexpected key raise, instead of
#   silently evaluating a model whose unmatched layers keep their initial weights.
model = VGG16()
model.load_state_dict(torch.load("./model/MPFA_9.pth", map_location="cpu"))

<All keys matched successfully>

In [7]:
# Put the model in inference mode (affects layers such as Dropout/BatchNorm, if the
# network has any) and move its weights to the selected device. Module.eval() and
# Module.to() both return the module itself, so chaining is safe. Assigning the result
# suppresses the (large) module repr without a dummy print().
model = model.eval().to(device)




In [8]:
# Freeze every parameter (model.parameters() recurses into submodules): evaluation only,
# so no gradients should ever be tracked for the weights.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate `model` on `data_loader`. Prints running means of the per-batch metrics
# every 100 batches, then the overall means and the elapsed time.
# The metric helpers (accuracy / precision / recall / F_Measure) come from the
# project's `utils` module. This cell only relies on their results supporting `+=`
# and division by an int.
total_batch = len(data_loader)  # recomputed so the progress line always matches this loader

start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batch_n = 0  # batches actually processed; stays 0 if the loader is empty

# The whole loop, forward pass included, runs without autograd. This is pure inference,
# and it keeps the cell correct even if the parameters were not frozen beforehand.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure of this batch's precision/recall. The reported figure is the mean of
        # these per-batch values, not F of the mean precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

end_time = time.perf_counter()

# Average over the batches actually processed (guarding against an empty loader).
n_done = max(batch_n, 1)
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_done,
              total_pre / n_done,
              total_rec / n_done,
              total_f_score / n_done))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.91696697473526 pre:0.7943442463874817 recall:0.8440214395523071 F-measure:0.8015055060386658
Batch:200/1255 acc:0.9199618101119995 pre:0.7920089364051819 recall:0.8384428024291992 F-measure:0.7974404692649841
Batch:300/1255 acc:0.9186856150627136 pre:0.7935035824775696 recall:0.8330270648002625 F-measure:0.796626091003418
Batch:400/1255 acc:0.9157906770706177 pre:0.7953324913978577 recall:0.8320383429527283 F-measure:0.7975715398788452
Batch:500/1255 acc:0.9140112996101379 pre:0.8018323183059692 recall:0.8281840682029724 F-measure:0.8021048903465271
Batch:600/1255 acc:0.912882387638092 pre:0.7993711829185486 recall:0.8296405673027039 F-measure:0.8004432320594788
Batch:700/1255 acc:0.9150714874267578 pre:0.798963189125061 recall:0.8299469947814941 F-measure:0.8003630042076111
Batch:800/1255 acc:0.9120721817016602 pre:0.8013990521430969 recall:0.8249518871307373 F-measure:0.8010526299476624
Batch:900/1255 acc:0.9120059013366699 pre:0.7838735580444336 recall:0.8133592

In [10]:
# Switch the evaluation set to DUT-OMRON (images + pixel-wise ground-truth PNGs).
# batch_size and target_size are reused unchanged from the configuration cell.
path_image, path_mask = (
    "./DUT-OMROM/DUT-OMRON-image/",
    "./DUT-OMROM/pixelwiseGT-new-PNG/",
)

data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)


In [11]:
# Number of batches in the DUT-OMRON loader: ceil(len(dataset) / batch_size), since
# the DataLoader is built without drop_last. Used as the denominator in the progress
# output of the next cell.
total_batch = len(data_loader)
total_batch

1292

In [12]:

# Evaluate `model` on `data_loader`. Prints running means of the per-batch metrics
# every 100 batches, then the overall means and the elapsed time.
# The metric helpers (accuracy / precision / recall / F_Measure) come from the
# project's `utils` module. This cell only relies on their results supporting `+=`
# and division by an int.
total_batch = len(data_loader)  # recomputed so the progress line always matches this loader

start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batch_n = 0  # batches actually processed; stays 0 if the loader is empty

# The whole loop, forward pass included, runs without autograd. This is pure inference,
# and it keeps the cell correct even if the parameters were not frozen beforehand.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure of this batch's precision/recall. The reported figure is the mean of
        # these per-batch values, not F of the mean precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

end_time = time.perf_counter()

# Average over the batches actually processed (guarding against an empty loader).
n_done = max(batch_n, 1)
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_done,
              total_pre / n_done,
              total_rec / n_done,
              total_f_score / n_done))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9166428446769714 pre:0.7354612350463867 recall:0.7273966670036316 F-measure:0.7156062722206116
Batch:200/1292 acc:0.912226140499115 pre:0.7153419256210327 recall:0.7083274126052856 F-measure:0.6980043053627014
Batch:300/1292 acc:0.907227098941803 pre:0.7189460396766663 recall:0.6962774991989136 F-measure:0.6979324221611023
Batch:400/1292 acc:0.9070225358009338 pre:0.7067827582359314 recall:0.7001788020133972 F-measure:0.6896087527275085
Batch:500/1292 acc:0.9089599847793579 pre:0.704644501209259 recall:0.7045775651931763 F-measure:0.6895018219947815
Batch:600/1292 acc:0.9108283519744873 pre:0.7008198499679565 recall:0.7134248614311218 F-measure:0.6888173818588257
Batch:700/1292 acc:0.9111068844795227 pre:0.7002547979354858 recall:0.7120121717453003 F-measure:0.6885794997215271
Batch:800/1292 acc:0.9110189080238342 pre:0.703556478023529 recall:0.712993860244751 F-measure:0.6915155649185181
Batch:900/1292 acc:0.912919819355011 pre:0.7072712182998657 recall:0.72096121