In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions next to the results (one version string per line).
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Evaluation set: DUTS-TE input images and their ground-truth masks.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Alternative benchmark (a later cell switches to it explicitly):
# path_image = "./DUT-OMROM/DUT-OMRON-image/"
# path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

# batch_size: limited by the available hardware; 4 was the largest that fit.
# target_size: passed to LoadTest; presumably the side length images are resized to — confirm.
batch_size, target_size = 4, 256

In [4]:
# Fixed (unshuffled) order so successive evaluation runs see identical batches.
test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [5]:
# Number of batches in one pass over the loader. DataLoader's default
# drop_last=False means a trailing partial batch is counted too.
# The bare name on the last line displays the value as cell output.
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and load the trained checkpoint.
model = VGG16()
checkpoint_path = "./model/MPFA_62.pth"
# map_location remaps stored tensors onto the device picked in the device cell,
# so a checkpoint containing CUDA tensors still loads on a CPU-only machine.
state_dict = torch.load(checkpoint_path, map_location=device)
# strict=True: a missing or unexpected key is an error instead of silently
# evaluating a partially initialised network. Displayed as cell output.
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [7]:
# Switch modules to evaluation behaviour (affects layers such as dropout or
# batch-norm, if present), then move the weights to the selected device.
model.eval()
# Trailing ';' suppresses the module repr Jupyter would otherwise echo.
model.to(device);




In [8]:
# Evaluation only: switch off gradient tracking on every weight tensor.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

start_time = time.perf_counter()  # monotonic clock, suited to measuring durations

# Running sums of the per-batch metrics, later divided by the number of batches seen.
# NOTE(review): if the dataset size is not a multiple of batch_size, the last batch
# is smaller and this per-batch mean weights its images more heavily than the rest.
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batches_seen = 0
log_every = 100  # print running averages every `log_every` batches

# Evaluation only: run the forward pass under no_grad as well, so no autograd
# graph is recorded even if some parameter still has requires_grad=True.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # Computed from this batch's `pre` / `rec`, not from the running averages.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        batches_seen = batch_n

        if batch_n % log_every == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format( batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

if batches_seen == 0:
    raise RuntimeError("data_loader yielded no batches")

# CUDA kernels run asynchronously; wait for them before reading the clock.
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / batches_seen,
              total_pre / batches_seen,
              total_rec / batches_seen,
              total_f_score / batches_seen))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9167243838310242 pre:0.8123090267181396 recall:0.7890137434005737 F-measure:0.8010938167572021
Batch:200/1255 acc:0.9195241928100586 pre:0.8110002875328064 recall:0.7785226106643677 F-measure:0.7969168424606323
Batch:300/1255 acc:0.9183520674705505 pre:0.810480535030365 recall:0.778235912322998 F-measure:0.7960134148597717
Batch:400/1255 acc:0.9157217144966125 pre:0.8121668100357056 recall:0.7794346213340759 F-measure:0.7974324822425842
Batch:500/1255 acc:0.913576602935791 pre:0.8172253966331482 recall:0.7750516533851624 F-measure:0.8003190159797668
Batch:600/1255 acc:0.9126439690589905 pre:0.8167240619659424 recall:0.7780510187149048 F-measure:0.800808846950531
Batch:700/1255 acc:0.9148070216178894 pre:0.8158077597618103 recall:0.778318464756012 F-measure:0.8003877401351929
Batch:800/1255 acc:0.9120745062828064 pre:0.8203173875808716 recall:0.7738358974456787 F-measure:0.8025531768798828
Batch:900/1255 acc:0.9120819568634033 pre:0.8028462529182434 recall:0.7633957

In [10]:
# Second benchmark: DUT-OMRON images and the masks from its pixelwiseGT folder.
# NOTE(review): "OMROM" in the directory name looks like a typo for "OMRON";
# kept as-is because it must match the folder on disk — confirm.
path_image, path_mask = (
    "./DUT-OMROM/DUT-OMRON-image/",
    "./DUT-OMROM/pixelwiseGT-new-PNG/",
)

# Same loader settings as the DUTS-TE run: fixed order, same batch size.
test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)


In [11]:
# Number of batches for the DUT-OMRON loader (a partial final batch is counted,
# since drop_last keeps its default of False). Displayed as cell output.
total_batch = len(data_loader)
total_batch

1292

In [12]:

# NOTE(review): this cell repeats the DUTS-TE evaluation loop verbatim; consider
# factoring the loop into one function that takes the loader as an argument.
start_time = time.perf_counter()  # monotonic clock, suited to measuring durations

# Running sums of the per-batch metrics, later divided by the number of batches seen.
# NOTE(review): if the dataset size is not a multiple of batch_size, the last batch
# is smaller and this per-batch mean weights its images more heavily than the rest.
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batches_seen = 0
log_every = 100  # print running averages every `log_every` batches

# Evaluation only: run the forward pass under no_grad as well, so no autograd
# graph is recorded even if some parameter still has requires_grad=True.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # Computed from this batch's `pre` / `rec`, not from the running averages.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        batches_seen = batch_n

        if batch_n % log_every == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format( batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

if batches_seen == 0:
    raise RuntimeError("data_loader yielded no batches")

# CUDA kernels run asynchronously; wait for them before reading the clock.
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / batches_seen,
              total_pre / batches_seen,
              total_rec / batches_seen,
              total_f_score / batches_seen))
print("--------------------------------------------------------------")

Batch:100/1292 acc:0.9179295301437378 pre:0.7464993000030518 recall:0.6984542608261108 F-measure:0.7195662260055542
Batch:200/1292 acc:0.91460782289505 pre:0.7300038933753967 recall:0.699652373790741 F-measure:0.7096100449562073
Batch:300/1292 acc:0.9083604216575623 pre:0.7282354235649109 recall:0.6807042956352234 F-measure:0.7027972936630249
Batch:400/1292 acc:0.908251166343689 pre:0.7147161960601807 recall:0.6829466223716736 F-measure:0.6928987503051758
Batch:500/1292 acc:0.9102231860160828 pre:0.7165209054946899 recall:0.6851282715797424 F-measure:0.6944660544395447
Batch:600/1292 acc:0.9120667576789856 pre:0.7133443355560303 recall:0.691012442111969 F-measure:0.6933930516242981
Batch:700/1292 acc:0.9126990437507629 pre:0.7143604755401611 recall:0.6891725659370422 F-measure:0.6940497756004333
Batch:800/1292 acc:0.9124419689178467 pre:0.7171961069107056 recall:0.6906369924545288 F-measure:0.696654736995697
Batch:900/1292 acc:0.9140454530715942 pre:0.7187913656234741 recall:0.69904470