In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Log the library versions so the results below can be tied to a known environment.
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Evaluation data: the DUTS-TE test split (input images + ground-truth masks).
# The DUT-OMRON benchmark is evaluated later in this notebook by reassigning
# these two paths, so no commented-out alternatives are kept here.
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

batch_size = 4  # constrained by the author's hardware; 4 was the largest that worked
target_size = 256  # passed to LoadTest (presumably the resize side length — TODO confirm)

In [4]:
# Deterministic (unshuffled) loader over the DUTS-TE test images and masks.
data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of batches one pass over `data_loader` yields (displayed as the cell output).
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and load the trained checkpoint.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine (the device cell falls back to CPU); the model is moved to
# `device` afterwards.
# strict=True turns any missing/unexpected key into an error instead of
# silently evaluating a partially initialised model.
model = VGG16()
model.load_state_dict(torch.load("./model/MPFA_23.pth", map_location="cpu"), strict=True)

<All keys matched successfully>

In [7]:
# Switch to inference mode (disables training-only behaviour such as dropout /
# batch-norm statistic updates) and move the weights to the selected device.
# nn.Module.eval() and .to() both return the module itself, so assigning the
# result keeps the cell from echoing the model repr without a stray print().
model = model.eval().to(device)




In [8]:
# Freeze every weight: evaluation never needs gradients for the parameters.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate `model` on the current `data_loader` (DUTS-TE) and report running
# averages of the per-batch metrics every `log_every` batches.
# NOTE(review): accuracy / precision / recall / F_Measure come from `utils`
# (star import). The printed output suggests they return Python floats that
# are per-batch means — TODO confirm. The final scores are means over
# batches, so a short last batch is weighted the same as a full one.
log_every = 100
total_batch = len(data_loader)  # recomputed so this cell does not rely on an earlier one

start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batches_done = 0

# Run the whole loop, forward pass included, under no_grad: this is pure
# inference, so autograd must not record a graph for model(image).
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        batches_done = batch_n

        if batch_n % log_every == 0:
            print("Batch:{}/{} acc:{} pre:{} recall:{} F-measure:{}".format(
                batch_n, total_batch,
                total_acc / batch_n,
                total_pre / batch_n,
                total_rec / batch_n,
                total_f_score / batch_n))

end_time = time.perf_counter()

if batches_done == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

# Average over the batches actually processed, not the nominal batch count.
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / batches_done,
              total_pre / batches_done,
              total_rec / batches_done,
              total_f_score / batches_done))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9206241965293884 pre:0.8311856389045715 recall:0.8233010768890381 F-measure:0.8252055048942566
Batch:200/1255 acc:0.9226438403129578 pre:0.8269964456558228 recall:0.8106570243835449 F-measure:0.8178164958953857
Batch:300/1255 acc:0.921343207359314 pre:0.8259572386741638 recall:0.8087315559387207 F-measure:0.8157474994659424
Batch:400/1255 acc:0.9187116622924805 pre:0.827335774898529 recall:0.8106575608253479 F-measure:0.8172340393066406
Batch:500/1255 acc:0.916764497756958 pre:0.8321555852890015 recall:0.8072196841239929 F-measure:0.8203149437904358
Batch:600/1255 acc:0.9159462451934814 pre:0.8314943909645081 recall:0.8090721368789673 F-measure:0.8205786943435669
Batch:700/1255 acc:0.9180043339729309 pre:0.8312079906463623 recall:0.8081604838371277 F-measure:0.8201680779457092
Batch:800/1255 acc:0.9149702191352844 pre:0.8332365155220032 recall:0.8033570647239685 F-measure:0.8203142285346985
Batch:900/1255 acc:0.9152230620384216 pre:0.8175768852233887 recall:0.79247

In [10]:
# Switch evaluation to the DUT-OMRON benchmark.
# The folder is spelled "DUT-OMROM" here; it must match the directory on disk,
# so it is left exactly as written.
path_image, path_mask = (
    "./DUT-OMROM/DUT-OMRON-image/",
    "./DUT-OMROM/pixelwiseGT-new-PNG/",
)

data_loader = data.DataLoader(
    dataset=LoadTest(path_image, path_mask, target_size),
    batch_size=batch_size,
    shuffle=False,
)


In [11]:
# Number of batches one pass over the DUT-OMRON `data_loader` yields (displayed as the cell output).
total_batch = len(data_loader)
total_batch

1292

In [12]:

# Evaluate `model` on the current `data_loader` (DUT-OMRON) and report running
# averages of the per-batch metrics every `log_every` batches.
# NOTE(review): accuracy / precision / recall / F_Measure come from `utils`
# (star import). The printed output suggests they return Python floats that
# are per-batch means — TODO confirm. The final scores are means over
# batches, so a short last batch is weighted the same as a full one.
log_every = 100
total_batch = len(data_loader)  # recomputed so this cell does not rely on an earlier one

start_time = time.perf_counter()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batches_done = 0

# Run the whole loop, forward pass included, under no_grad: this is pure
# inference, so autograd must not record a graph for model(image).
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        batches_done = batch_n

        if batch_n % log_every == 0:
            print("Batch:{}/{} acc:{} pre:{} recall:{} F-measure:{}".format(
                batch_n, total_batch,
                total_acc / batch_n,
                total_pre / batch_n,
                total_rec / batch_n,
                total_f_score / batch_n))

end_time = time.perf_counter()

if batches_done == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

# Average over the batches actually processed, not the nominal batch count.
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / batches_done,
              total_pre / batches_done,
              total_rec / batches_done,
              total_f_score / batches_done))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9200150966644287 pre:0.7621504664421082 recall:0.711332380771637 F-measure:0.7327218651771545
Batch:200/1292 acc:0.9163889288902283 pre:0.74155592918396 recall:0.7027488350868225 F-measure:0.7175220847129822
Batch:300/1292 acc:0.9113218188285828 pre:0.7465049028396606 recall:0.6896727681159973 F-measure:0.7169625759124756
Batch:400/1292 acc:0.9111647605895996 pre:0.7375941872596741 recall:0.690802812576294 F-measure:0.7092563509941101
Batch:500/1292 acc:0.9131141304969788 pre:0.7345632314682007 recall:0.6968173980712891 F-measure:0.7092639207839966
Batch:600/1292 acc:0.9157516956329346 pre:0.7354938983917236 recall:0.7067985534667969 F-measure:0.7128739953041077
Batch:700/1292 acc:0.9161795377731323 pre:0.7358300089836121 recall:0.7055700421333313 F-measure:0.7136293649673462
Batch:800/1292 acc:0.9159947037696838 pre:0.7383900880813599 recall:0.7071083784103394 F-measure:0.7162060141563416
Batch:900/1292 acc:0.9176831841468811 pre:0.7403017282485962 recall:0.716038