In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the versions of the core libraries (one per line) so the metrics
# printed further down can be tied to a concrete software environment.
print(*(lib.__version__ for lib in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Evaluation settings.
target_size = 256  # passed to LoadTest; presumably the resize side length — TODO confirm
batch_size = 4  # constrained by a tight hardware budget; 4 is the upper limit

# DUTS-TE benchmark: ground-truth masks and the matching input images.
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"

# Alternative dataset (DUT-OMRON), kept here for quick switching:
# path_image = "./DUT-OMROM/DUT-OMRON-image/"
# path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

In [4]:
# Sequential (unshuffled) loader over the image/mask pairs under
# path_image / path_mask.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of mini-batches the loader yields in one pass; the evaluation cell
# uses it for progress reporting and as the denominator of the final averages.
total_batch = len(data_loader)
total_batch

1255

In [7]:
# Build the network and restore the trained weights.
model = VGG16()
# map_location="cpu" keeps the checkpoint loadable on a machine without CUDA
# (the notebook explicitly supports a CPU fallback); the model is moved to
# `device` in a later cell. Without it, a checkpoint saved from GPU tensors
# fails to deserialize on a CPU-only host.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
# strict=False tolerates missing/unexpected keys; the returned report is the
# cell's displayed value, so verify it says that all keys matched.
model.load_state_dict(
    torch.load("./model/MPFA_74.pth", map_location="cpu"),
    strict=False,
)

<All keys matched successfully>

In [8]:
# Move the network to the selected device and switch it to evaluation mode.
model.to(device).eval()
# End the cell with a bare print() so the module's repr is not echoed.
print()




In [9]:
# Freeze every parameter: this notebook only evaluates, so no gradients
# are needed.
for param in model.parameters():
    param.requires_grad_(False)

In [10]:

# Evaluate `model` on the current `data_loader`: accumulate per-batch
# accuracy, precision, recall and F-measure, print the running means every
# 100 batches and the final means (plus wall-clock time) at the end.
# NOTE(review): every batch carries equal weight in these means, so a final
# batch smaller than `batch_size` would be slightly over-weighted — confirm
# this matches the intended protocol.

# Recompute the batch count from the loader that is actually iterated, so a
# stale `total_batch` left over from another cell cannot skew the averages.
total_batch = len(data_loader)
if total_batch == 0:
    raise ValueError("data_loader is empty; check path_image / path_mask")

start_time = time.time()

total_loss = 0  # never updated below; kept so the notebook namespace is unchanged
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

for batch_n, (image, mask) in enumerate(data_loader, start=1):

    image = image.to(device)
    mask = mask.to(device)

    # The forward pass runs under no_grad too, so no autograd graph is
    # recorded even if the parameters were not frozen by an earlier cell.
    with torch.no_grad():
        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure is derived from this batch's precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

    # Periodic progress report with the running means so far.
    if batch_n % 100 == 0:
        avg_acc = total_acc / batch_n
        avg_pre = total_pre / batch_n
        avg_rec = total_rec / batch_n
        avg_f_score = total_f_score / batch_n
        print("Batch:{}/{}".format( batch_n, total_batch), end="")
        print(" acc:{} pre:{} recall:{} F-measure:{}"
              .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time, 
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9182862043380737 pre:0.8161500096321106 recall:0.8009535074234009 F-measure:0.807947039604187
Batch:200/1255 acc:0.9214984178543091 pre:0.8213359713554382 recall:0.7928736805915833 F-measure:0.8092164993286133
Batch:300/1255 acc:0.9199885129928589 pre:0.8199112415313721 recall:0.7893044352531433 F-measure:0.8060669302940369
Batch:400/1255 acc:0.9170066714286804 pre:0.8194054365158081 recall:0.7899878025054932 F-measure:0.8056906461715698
Batch:500/1255 acc:0.9148575067520142 pre:0.8245668411254883 recall:0.7867311835289001 F-measure:0.8090230822563171
Batch:600/1255 acc:0.9137870669364929 pre:0.821807861328125 recall:0.7889648079872131 F-measure:0.8076457381248474
Batch:700/1255 acc:0.9160552620887756 pre:0.8223387002944946 recall:0.7884711623191833 F-measure:0.808036744594574
Batch:800/1255 acc:0.9130865931510925 pre:0.8251277804374695 recall:0.7838579416275024 F-measure:0.808673620223999
Batch:900/1255 acc:0.9131733179092407 pre:0.8075164556503296 recall:0.772271

In [11]:
# Point the evaluation at the DUT-OMRON dataset and rebuild the loader with
# the same target size and batch size as before.
path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"
path_image = "./DUT-OMROM/DUT-OMRON-image/"

data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)


In [12]:
# Refresh the batch count for the newly built loader; the evaluation cell
# below divides the accumulated metrics by this value.
total_batch = len(data_loader)
total_batch

1292

In [13]:

# Evaluate `model` on the current `data_loader`: accumulate per-batch
# accuracy, precision, recall and F-measure, print the running means every
# 100 batches and the final means (plus wall-clock time) at the end.
# NOTE(review): every batch carries equal weight in these means, so a final
# batch smaller than `batch_size` would be slightly over-weighted — confirm
# this matches the intended protocol.

# Recompute the batch count from the loader that is actually iterated, so a
# stale `total_batch` left over from another cell cannot skew the averages.
total_batch = len(data_loader)
if total_batch == 0:
    raise ValueError("data_loader is empty; check path_image / path_mask")

start_time = time.time()

total_loss = 0  # never updated below; kept so the notebook namespace is unchanged
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

for batch_n, (image, mask) in enumerate(data_loader, start=1):

    image = image.to(device)
    mask = mask.to(device)

    # The forward pass runs under no_grad too, so no autograd graph is
    # recorded even if the parameters were not frozen by an earlier cell.
    with torch.no_grad():
        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure is derived from this batch's precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

    # Periodic progress report with the running means so far.
    if batch_n % 100 == 0:
        avg_acc = total_acc / batch_n
        avg_pre = total_pre / batch_n
        avg_rec = total_rec / batch_n
        avg_f_score = total_f_score / batch_n
        print("Batch:{}/{}".format( batch_n, total_batch), end="")
        print(" acc:{} pre:{} recall:{} F-measure:{}"
              .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time, 
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9179481267929077 pre:0.7451664805412292 recall:0.701422393321991 F-measure:0.7189478874206543
Batch:200/1292 acc:0.9153153300285339 pre:0.7344366312026978 recall:0.6937055587768555 F-measure:0.71123868227005
Batch:300/1292 acc:0.9101107716560364 pre:0.7386555671691895 recall:0.6812536120414734 F-measure:0.7093965411186218
Batch:400/1292 acc:0.9106594920158386 pre:0.7301834225654602 recall:0.6857808828353882 F-measure:0.70455402135849
Batch:500/1292 acc:0.9122055768966675 pre:0.7267736792564392 recall:0.6910449862480164 F-measure:0.7038416862487793
Batch:600/1292 acc:0.914008617401123 pre:0.7230703830718994 recall:0.6978892683982849 F-measure:0.7027209401130676
Batch:700/1292 acc:0.9147052764892578 pre:0.7246558666229248 recall:0.6984393000602722 F-measure:0.7046276330947876
Batch:800/1292 acc:0.9145584106445312 pre:0.7278995513916016 recall:0.6996550559997559 F-measure:0.7075554728507996
Batch:900/1292 acc:0.9162071347236633 pre:0.729479968547821 recall:0.707365691