In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions this notebook was run with, one per line.
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# DUTS-TE test split: input images and their ground-truth masks.
# (The DUT-OMRON split is evaluated in a later cell by reassigning these two paths.)
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

batch_size = 4  # limited by the available hardware: 4 is the largest batch that fits
target_size = 256  # passed to LoadTest; presumably the resize edge length in pixels — TODO confirm

In [4]:
# Test-time loader over the configured split; no shuffling, so batch order is fixed.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of mini-batches in one pass over the test set (displayed as the cell output).
total_batch = len(data_loader)
# The evaluation cell divides by this count, so fail loudly on an empty dataset.
assert total_batch > 0, "data_loader is empty - check path_image / path_mask"
total_batch

1255

In [6]:
# Build the network and restore the trained weights.
# map_location places the stored tensors on `device`, so a checkpoint saved from
# GPU tensors also loads on a CPU-only machine (where `device` falls back to CPU).
# NOTE(review): strict=False silently ignores missing/unexpected keys; check the
# key report returned by load_state_dict (shown as this cell's output).
model = VGG16()
model.load_state_dict(torch.load("./model/MPFA_35.pth", map_location=device), strict=False)

<All keys matched successfully>

In [7]:
# Switch the model to evaluation mode (changes the behaviour of layers such as
# dropout / batch-norm, if the model contains any) and move it to `device`.
model.eval()
model.to(device);  # trailing ';' hides the module repr in Jupyter without printing a blank line




In [8]:
# Freeze every parameter: this notebook only evaluates, it never trains.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model on every batch of `data_loader`. Running averages are printed
# every 100 batches, and the overall averages plus elapsed time are printed at the end.
# NOTE(review): accuracy / precision / recall / F_Measure come from
# `from utils import *`; the printed output suggests they return plain floats,
# presumably per-batch scores — TODO confirm. The reported figures are therefore
# means of per-batch scores, so a short final batch counts as much as a full one.
total_batch = len(data_loader)  # recomputed here so a rebuilt loader never uses a stale count
start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# Run the forward pass (not just the metrics) without autograd. Inference needs no
# graph, and this no longer relies on the parameter-freezing cell having been run.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.perf_counter()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9188623428344727 pre:0.8231050968170166 recall:0.8173074722290039 F-measure:0.8171997666358948
Batch:200/1255 acc:0.9224636554718018 pre:0.8251875042915344 recall:0.8134044408798218 F-measure:0.8168741464614868
Batch:300/1255 acc:0.9210756421089172 pre:0.8215697407722473 recall:0.8126248717308044 F-measure:0.8134785294532776
Batch:400/1255 acc:0.9179564714431763 pre:0.8214311599731445 recall:0.814199686050415 F-measure:0.8132904767990112
Batch:500/1255 acc:0.916031539440155 pre:0.8279200196266174 recall:0.809859037399292 F-measure:0.8173554539680481
Batch:600/1255 acc:0.915207028388977 pre:0.8268163800239563 recall:0.8128423690795898 F-measure:0.8173531889915466
Batch:700/1255 acc:0.9173167943954468 pre:0.825914740562439 recall:0.8130220770835876 F-measure:0.8166686296463013
Batch:800/1255 acc:0.9145312309265137 pre:0.8289052248001099 recall:0.8091350197792053 F-measure:0.8180011510848999
Batch:900/1255 acc:0.9146939516067505 pre:0.8113747239112854 recall:0.7994425

In [10]:
# Point the evaluation at the DUT-OMRON images / masks and rebuild the loader.
# NOTE(review): the directory is spelled "DUT-OMROM"; presumably this matches the
# on-disk layout — verify before renaming anything.
path_image = "./DUT-OMROM/DUT-OMRON-image/"
path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)


In [11]:
# Number of mini-batches in one pass over DUT-OMRON (displayed as the cell output).
total_batch = len(data_loader)
# The evaluation cell divides by this count, so fail loudly on an empty dataset.
assert total_batch > 0, "data_loader is empty - check path_image / path_mask"
total_batch

1292

In [12]:

# Evaluate the model on every batch of the DUT-OMRON `data_loader`. Running averages
# are printed every 100 batches, and the overall averages plus elapsed time at the end.
# NOTE(review): accuracy / precision / recall / F_Measure come from
# `from utils import *`; the printed output suggests they return plain floats,
# presumably per-batch scores — TODO confirm. The reported figures are therefore
# means of per-batch scores, so a short final batch counts as much as a full one.
total_batch = len(data_loader)  # recomputed here so a rebuilt loader never uses a stale count
start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# Run the forward pass (not just the metrics) without autograd. Inference needs no
# graph, and this no longer relies on the parameter-freezing cell having been run.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.perf_counter()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9222933650016785 pre:0.7668030261993408 recall:0.735341489315033 F-measure:0.7447482943534851
Batch:200/1292 acc:0.9186984300613403 pre:0.7471588253974915 recall:0.7281694412231445 F-measure:0.7299389243125916
Batch:300/1292 acc:0.9131788611412048 pre:0.7473057508468628 recall:0.7123088240623474 F-measure:0.7259295582771301
Batch:400/1292 acc:0.912373960018158 pre:0.7350950241088867 recall:0.7127103805541992 F-measure:0.714939534664154
Batch:500/1292 acc:0.9133810997009277 pre:0.7295538783073425 recall:0.7169082164764404 F-measure:0.711651086807251
Batch:600/1292 acc:0.9155968427658081 pre:0.7283015251159668 recall:0.7259585857391357 F-measure:0.7131957411766052
Batch:700/1292 acc:0.9161477088928223 pre:0.7293208241462708 recall:0.7245139479637146 F-measure:0.7142395973205566
Batch:800/1292 acc:0.9155284762382507 pre:0.73087477684021 recall:0.7243167757987976 F-measure:0.7155505418777466
Batch:900/1292 acc:0.9171620011329651 pre:0.7324743270874023 recall:0.73215252