In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the versions of the core libraries so results can be reproduced.
print(*(lib.__version__ for lib in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Evaluation target: the DUTS test split (input images and ground-truth masks).
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"

# Alternative benchmark (DUT-OMRON); swap these in to evaluate it instead.
# path_image = "./DUT-OMROM/DUT-OMRON-image/"
# path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

target_size = 256  # size argument handed to LoadTest
batch_size = 4  # limited by the available hardware budget; 4 is the maximum that fits

In [4]:
# Wrap the test set in a DataLoader; shuffle=False keeps batches in the
# dataset's own order so runs are comparable.
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
total_batch = len(data_loader)  # batches in one full pass over the test set
total_batch

1255

In [6]:
# Build the network and restore the trained weights. map_location=device lets a
# checkpoint that was saved on GPU load even when the notebook fell back to CPU.
model = VGG16()
load_result = model.load_state_dict(
    torch.load("./model/MPFA_13.pth", map_location=device), strict=False
)
# strict=False tolerates extra checkpoint entries, but a missing key would leave
# that parameter at its random initialisation and silently corrupt every metric
# computed below, so refuse to continue in that case.
assert not load_result.missing_keys, \
    "checkpoint is missing keys: {}".format(load_result.missing_keys)
load_result

<All keys matched successfully>

In [7]:
# Put the model in evaluation mode (changes the behaviour of modules such as
# dropout / batch-norm, if the network has any) and move its weights to `device`.
# Module.to() modifies the module in place; the trailing ';' hides the module
# repr in the notebook instead of the former print() workaround.
model.eval()
model.to(device);




In [8]:
# Freeze every parameter of the model: evaluation never needs their gradients.
# (';' keeps the returned module from being displayed.)
model.requires_grad_(False);

In [9]:

# Evaluate the frozen model on every batch of `data_loader`. Running averages of
# accuracy / precision / recall / F-measure are printed every 100 batches, and the
# final averages at the end. The metric functions come from `utils`.
# NOTE(review): each batch's metric is weighted equally regardless of its size
# (the last batch may be smaller), so this is a mean of batch scores, not a
# strict per-image mean.
start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batch_n = 0  # stays 0 for an empty loader so the summary below cannot NameError

# Gradients are never needed here; cover the forward pass as well, not only the
# metric computation.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        pre = precision(predict, mask)
        rec = recall(predict, mask)
        f_score = F_Measure(pre, rec)

        total_acc += acc
        total_pre += pre
        total_rec += rec
        total_f_score += f_score

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

end_time = time.perf_counter()

# Average over the batches actually processed (equal to total_batch after a full
# pass); max(..., 1) avoids ZeroDivisionError on an empty loader.
n_batches = max(batch_n, 1)
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9154441356658936 pre:0.7836655974388123 recall:0.8492724299430847 F-measure:0.7935712933540344
Batch:200/1255 acc:0.9193875789642334 pre:0.7886249423027039 recall:0.841928243637085 F-measure:0.7950478792190552
Batch:300/1255 acc:0.9183208346366882 pre:0.7895850539207458 recall:0.8377615809440613 F-measure:0.7943596839904785
Batch:400/1255 acc:0.9156412482261658 pre:0.7915717959403992 recall:0.8377980589866638 F-measure:0.7958022952079773
Batch:500/1255 acc:0.9141197204589844 pre:0.7993139624595642 recall:0.8347060680389404 F-measure:0.8014042377471924
Batch:600/1255 acc:0.9131703972816467 pre:0.7982932329177856 recall:0.8358916640281677 F-measure:0.8010583519935608
Batch:700/1255 acc:0.9155195951461792 pre:0.7998061776161194 recall:0.8348545432090759 F-measure:0.8020560145378113
Batch:800/1255 acc:0.9127400517463684 pre:0.8033658266067505 recall:0.8305174708366394 F-measure:0.8036215305328369
Batch:900/1255 acc:0.9129847288131714 pre:0.7876274585723877 recall:0.820

In [10]:
# Switch the evaluation target to the DUT-OMRON benchmark. The folder name is
# spelled "DUT-OMROM" on disk; the paths must match it exactly.
path_image, path_mask = (
    "./DUT-OMROM/DUT-OMRON-image/",
    "./DUT-OMROM/pixelwiseGT-new-PNG/",
)

omron_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(omron_dataset, batch_size=batch_size, shuffle=False)


In [11]:
total_batch = len(data_loader)  # batches in one full pass over DUT-OMRON
total_batch

1292

In [12]:

# Evaluate the frozen model on every batch of `data_loader`. Running averages of
# accuracy / precision / recall / F-measure are printed every 100 batches, and the
# final averages at the end. The metric functions come from `utils`.
# NOTE(review): each batch's metric is weighted equally regardless of its size
# (the last batch may be smaller), so this is a mean of batch scores, not a
# strict per-image mean.
start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batch_n = 0  # stays 0 for an empty loader so the summary below cannot NameError

# Gradients are never needed here; cover the forward pass as well, not only the
# metric computation.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        pre = precision(predict, mask)
        rec = recall(predict, mask)
        f_score = F_Measure(pre, rec)

        total_acc += acc
        total_pre += pre
        total_rec += rec
        total_f_score += f_score

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

end_time = time.perf_counter()

# Average over the batches actually processed (equal to total_batch after a full
# pass); max(..., 1) avoids ZeroDivisionError on an empty loader.
n_batches = max(batch_n, 1)
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9160940051078796 pre:0.7323072552680969 recall:0.7230894565582275 F-measure:0.7116811871528625
Batch:200/1292 acc:0.9132612347602844 pre:0.7174451351165771 recall:0.7206259965896606 F-measure:0.7038479447364807
Batch:300/1292 acc:0.9087569117546082 pre:0.7217282056808472 recall:0.7097184062004089 F-measure:0.7049781680107117
Batch:400/1292 acc:0.9084843397140503 pre:0.7107465267181396 recall:0.7113921046257019 F-measure:0.696537971496582
Batch:500/1292 acc:0.9107286930084229 pre:0.7118368148803711 recall:0.7182673811912537 F-measure:0.6987098455429077
Batch:600/1292 acc:0.9128850698471069 pre:0.7112278938293457 recall:0.728011429309845 F-measure:0.7004060745239258
Batch:700/1292 acc:0.9133540391921997 pre:0.7113133668899536 recall:0.7263215184211731 F-measure:0.7007025480270386
Batch:800/1292 acc:0.9130309224128723 pre:0.7133192420005798 recall:0.7270708084106445 F-measure:0.7026651501655579
Batch:900/1292 acc:0.9146514534950256 pre:0.7148489952087402 recall:0.7355