In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions this notebook was run with (one per line).
print(*(lib.__version__ for lib in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# DUTS test split: input images and their ground-truth masks.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)
# (DUT-OMRON paths are set in a later cell, where that benchmark is evaluated.)

target_size = 256  # passed to LoadTest as the resize target
batch_size = 4  # hardware-limited: 4 is the largest batch that fits

In [4]:
# Unshuffled loader over the test image/mask pairs, so batches come out in a
# fixed order and the logged running averages are reproducible across runs.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    shuffle=False,
    batch_size=batch_size,
)

In [5]:
# Batches in one pass over the loader (drop_last is not set, so the final
# batch may be partial).
total_batch = len(data_loader)
total_batch

1255

In [10]:
# Build the network and load the trained MPFA weights.
model = VGG16()
# map_location remaps stored tensors onto the device chosen in the device cell,
# so a checkpoint saved from a GPU run also loads on the CPU fallback.
# torch.load unpickles the file: only load checkpoints from a trusted source.
# strict=False tolerates key mismatches; the returned report (shown as the
# cell output) lists any missing/unexpected keys and should be checked.
model.load_state_dict(torch.load("./model/MPFA_51.pth", map_location=device), strict=False)

<All keys matched successfully>

In [11]:
# Put the model in inference mode (changes dropout/batch-norm behaviour, if the
# network has such layers) and move its weights to the selected device.
# Module.to/eval return the module itself; assigning the result keeps the
# notebook from echoing the full module repr, so no dummy print() is needed.
model = model.to(device).eval()




In [12]:
# Evaluation only: mark every parameter as not requiring a gradient.
# (Module.requires_grad_ applies requires_grad_ to each parameter; the trailing
# semicolon suppresses the notebook's display of the returned module.)
model.requires_grad_(False);

In [13]:

# Evaluate the frozen model on the current `data_loader`, printing running
# averages of the per-batch metrics every `log_every` batches and the overall
# averages at the end. accuracy / precision / recall / F_Measure come from utils.
# NOTE(review): results are means of per-batch scores, so a final partial batch
# is weighted like a full one — confirm this matches the intended metric.
log_every = 100
n_total = len(data_loader)  # derived here so a swapped loader can't leave a stale count

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # batches actually processed; used as the final denominator

start_time = time.perf_counter()
# The forward pass runs under no_grad as well, so no autograd graph is built
# even if the parameter-freezing cell was skipped.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            print("Batch:{}/{}".format(batch_n, n_total), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(total_acc / batch_n,
                          total_pre / batch_n,
                          total_rec / batch_n,
                          total_f_score / batch_n))
elapsed = time.perf_counter() - start_time

if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(elapsed,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9191281199455261 pre:0.816030740737915 recall:0.8355982303619385 F-measure:0.8160198926925659
Batch:200/1255 acc:0.9224180579185486 pre:0.8213722705841064 recall:0.8251014351844788 F-measure:0.8167349100112915
Batch:300/1255 acc:0.9209890961647034 pre:0.8190389275550842 recall:0.8223370909690857 F-measure:0.8137248158454895
Batch:400/1255 acc:0.9185909032821655 pre:0.8223724961280823 recall:0.8223307728767395 F-measure:0.8162429332733154
Batch:500/1255 acc:0.9169034957885742 pre:0.8293876647949219 recall:0.819004476070404 F-measure:0.8211193680763245
Batch:600/1255 acc:0.9159166812896729 pre:0.8277755379676819 recall:0.8209471106529236 F-measure:0.8203755617141724
Batch:700/1255 acc:0.9181726574897766 pre:0.8279935717582703 recall:0.8210899233818054 F-measure:0.8207377195358276
Batch:800/1255 acc:0.9154171347618103 pre:0.8316827416419983 recall:0.8164752125740051 F-measure:0.8222788572311401
Batch:900/1255 acc:0.9157974123954773 pre:0.815994918346405 recall:0.80486

In [14]:
# Second benchmark: DUT-OMRON. Reuses batch_size / target_size from the config
# cell and replaces `data_loader` for the evaluation cell below.
# NOTE(review): the directory is spelled "DUT-OMROM" (not "DUT-OMRON"); kept
# as-is on the assumption it matches the on-disk layout — confirm.
path_image, path_mask = (
    "./DUT-OMROM/DUT-OMRON-image/",
    "./DUT-OMROM/pixelwiseGT-new-PNG/",
)
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    shuffle=False,
    batch_size=batch_size,
)


In [15]:
# Batches in one pass over the DUT-OMRON loader (drop_last is not set, so the
# final batch may be partial).
total_batch = len(data_loader)
total_batch

1292

In [None]:

# Evaluate the frozen model on the current `data_loader`, printing running
# averages of the per-batch metrics every `log_every` batches and the overall
# averages at the end. accuracy / precision / recall / F_Measure come from utils.
# NOTE(review): results are means of per-batch scores, so a final partial batch
# is weighted like a full one — confirm this matches the intended metric.
log_every = 100
n_total = len(data_loader)  # derived here so a swapped loader can't leave a stale count

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # batches actually processed; used as the final denominator

start_time = time.perf_counter()
# The forward pass runs under no_grad as well, so no autograd graph is built
# even if the parameter-freezing cell was skipped.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            print("Batch:{}/{}".format(batch_n, n_total), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(total_acc / batch_n,
                          total_pre / batch_n,
                          total_rec / batch_n,
                          total_f_score / batch_n))
elapsed = time.perf_counter() - start_time

if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(elapsed,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9205004572868347 pre:0.7606207728385925 recall:0.7369702458381653 F-measure:0.7387913465499878
Batch:200/1292 acc:0.9172635674476624 pre:0.7403751015663147 recall:0.7279025316238403 F-measure:0.7247220277786255
Batch:300/1292 acc:0.9117786884307861 pre:0.7394692897796631 recall:0.7124938368797302 F-measure:0.7195568680763245
Batch:400/1292 acc:0.9118021726608276 pre:0.7313506007194519 recall:0.7132731676101685 F-measure:0.7123115658760071
Batch:500/1292 acc:0.9135990738868713 pre:0.7298102378845215 recall:0.7164260745048523 F-measure:0.7121318578720093
Batch:600/1292 acc:0.9155578017234802 pre:0.7266045808792114 recall:0.7236361503601074 F-measure:0.7118256092071533
Batch:700/1292 acc:0.9161108136177063 pre:0.7272526621818542 recall:0.7215961217880249 F-measure:0.7125101685523987
Batch:800/1292 acc:0.9157483577728271 pre:0.7295728325843811 recall:0.7221245169639587 F-measure:0.7138474583625793
Batch:900/1292 acc:0.9173011779785156 pre:0.731046199798584 recall:0.730