In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions this notebook was run with, one per line.
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Image and ground-truth mask directories for the first evaluation run (DUTS-TE).
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

# Alternative dataset directories (DUT-OMRON), left disabled in this cell:
# path_image = "./DUT-OMROM/DUT-OMRON-image/"
# path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

batch_size = 4 # limited by a tight hardware budget; 4 is the most that fits
# Passed to LoadTest together with the two paths.
# NOTE(review): presumably the side length images/masks are resized to — confirm in LoadTest.
target_size = 256

In [4]:
# Batches are served in dataset order (no shuffling) for evaluation.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of batches the loader yields; the evaluation cell uses it as the
# progress denominator and as the divisor for the final metric averages.
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and restore the trained weights.
model = VGG16()
# map_location="cpu" deserialises the checkpoint onto the CPU even if it was
# saved from CUDA tensors: loading then also works on a machine without a GPU,
# and no temporary GPU copy of the checkpoint is allocated.
# NOTE(review): strict=False tolerates missing/unexpected keys; the result of
# load_state_dict is left as the cell's last expression so any mismatch is
# still displayed.
model.load_state_dict(torch.load("./model/MPFA_44.pth", map_location="cpu"), strict=False)

<All keys matched successfully>

In [7]:
# Switch to evaluation mode, then move the module to the selected device.
# The result is deliberately not the cell's last expression; print() follows
# so the cell emits only a blank line.
model.eval().to(device)
print()




In [8]:
# Freeze every parameter: no gradients are needed for evaluation.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model on every batch of `data_loader`: accumulate the per-batch
# accuracy / precision / recall / F-measure, print the running mean every
# 100 batches, and print the overall mean and elapsed time at the end.
start_time = time.time()

total_loss = 0  # never updated in this cell; kept so the name stays defined
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# A single no_grad scope covers the forward pass as well as the metric
# computation, so inference never records an autograd graph, whether or not
# the parameters were frozen beforehand.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure is derived from this batch's precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        # Periodic progress report: mean of the per-batch metrics so far.
        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()
print("--------------------------------------------------------------")
# Final figures are means over batches (each batch weighted equally).
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9196802973747253 pre:0.8209121227264404 recall:0.8303541541099548 F-measure:0.8188225626945496
Batch:200/1255 acc:0.9229786396026611 pre:0.8262304663658142 recall:0.8197686672210693 F-measure:0.8196090459823608
Batch:300/1255 acc:0.9219977259635925 pre:0.8270219564437866 recall:0.8157817721366882 F-measure:0.8187811970710754
Batch:400/1255 acc:0.9191440939903259 pre:0.8283149003982544 recall:0.8162267208099365 F-measure:0.8195270895957947
Batch:500/1255 acc:0.9170910120010376 pre:0.8334200382232666 recall:0.8116655945777893 F-measure:0.8224098086357117
Batch:600/1255 acc:0.9161903858184814 pre:0.8316251039505005 recall:0.8148596882820129 F-measure:0.8219941854476929
Batch:700/1255 acc:0.9183616042137146 pre:0.8314561247825623 recall:0.8149453401565552 F-measure:0.821959376335144
Batch:800/1255 acc:0.9153560996055603 pre:0.833585798740387 recall:0.8101394176483154 F-measure:0.8221545219421387
Batch:900/1255 acc:0.9154444336891174 pre:0.8158050179481506 recall:0.7999

In [10]:
# Point the evaluation at the second dataset and rebuild the loader with the
# same target size and batch size as before.
path_image = "./DUT-OMROM/DUT-OMRON-image/"
path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)


In [11]:
# Recompute the batch count for the rebuilt loader; the evaluation cell below
# uses it as the progress denominator and the final averaging divisor.
total_batch = len(data_loader)
total_batch

1292

In [12]:

# Evaluate the model on every batch of the current `data_loader`: accumulate
# the per-batch accuracy / precision / recall / F-measure, print the running
# mean every 100 batches, and print the overall mean and elapsed time at the end.
start_time = time.time()

total_loss = 0  # never updated in this cell; kept so the name stays defined
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# A single no_grad scope covers the forward pass as well as the metric
# computation, so inference never records an autograd graph, whether or not
# the parameters were frozen beforehand.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure is derived from this batch's precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        # Periodic progress report: mean of the per-batch metrics so far.
        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()
print("--------------------------------------------------------------")
# Final figures are means over batches (each batch weighted equally).
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9192777872085571 pre:0.7535431385040283 recall:0.7260148525238037 F-measure:0.7315698862075806
Batch:200/1292 acc:0.9168419241905212 pre:0.7371069192886353 recall:0.7208317518234253 F-measure:0.7204105854034424
Batch:300/1292 acc:0.9124143719673157 pre:0.7415094375610352 recall:0.7097541689872742 F-measure:0.7211188673973083
Batch:400/1292 acc:0.9120386838912964 pre:0.7314044833183289 recall:0.7113921046257019 F-measure:0.7131035327911377
Batch:500/1292 acc:0.9135894179344177 pre:0.7303426861763 recall:0.715121865272522 F-measure:0.7130498290061951
Batch:600/1292 acc:0.9157190918922424 pre:0.7289348244667053 recall:0.7238461971282959 F-measure:0.714523434638977
Batch:700/1292 acc:0.9162582755088806 pre:0.7301914095878601 recall:0.7219455242156982 F-measure:0.7154025435447693
Batch:800/1292 acc:0.9158943891525269 pre:0.7326093912124634 recall:0.7218431830406189 F-measure:0.7169319987297058
Batch:900/1292 acc:0.9175925850868225 pre:0.7346745729446411 recall:0.7294052