In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Report the installed version of each core dependency, one per line.
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [5]:
# Evaluation data: DUTS-TE input images and, presumably, their ground-truth masks
# (both directories are passed to LoadTest below).
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

# Alternative evaluation set (DUT-OMRON); uncomment to evaluate on it instead.
# path_image = "./DUT-OMROM/DUT-OMRON-image/"
# path_mask = "./DUT-OMROM/pixelwiseGT-new-PNG/"

batch_size = 4  # constrained by limited (hardware) budget; 4 is the practical maximum
target_size = 256  # passed to LoadTest; presumably the resize resolution — TODO confirm

In [6]:
# Unshuffled loader over the test split, so batches follow dataset order and
# repeated runs see identical batches.
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [7]:
# Number of mini-batches in one pass over the test loader (displayed below).
total_batch = len(data_loader)
total_batch

1255

In [8]:
# Build the network and load the trained weights.
# map_location=device lets a checkpoint saved from a GPU load on the CPU fallback
# device selected earlier; without it torch.load fails on CPU-only machines.
checkpoint_path = "./model/model_61_edge.pth"
model = VGG16()
load_result = model.load_state_dict(torch.load(checkpoint_path, map_location=device),
                                    strict=False)
# strict=False tolerates extra checkpoint keys, but evaluating with model weights
# the checkpoint did not provide would silently produce meaningless metrics.
assert not load_result.missing_keys, \
    "checkpoint is missing weights: {}".format(load_result.missing_keys)
load_result

<All keys matched successfully>

In [9]:
# Put the model in inference mode (disables training-only behaviour such as
# dropout / batch-norm updates, if the model has any) and move its weights to the
# selected device. The trailing ';' suppresses the module repr, replacing the
# previous print() workaround that emitted a blank line.
model.eval()
model.to(device);




In [10]:
# Freeze every weight tensor: evaluation never needs gradients.
for param in model.parameters():
    param.requires_grad_(False)

In [11]:

# Evaluate the loaded model over the whole test loader. Running averages are
# printed every `log_interval` batches, followed by a final summary with the
# wall-clock time.
#
# The metric helpers (`accuracy`, `precision`, `recall`, `F_Measure`) come from the
# star import of `utils`. Averages are means of per-batch values, so the last batch
# (which may be smaller) carries the same weight as a full one.
# NOTE(review): the helpers appear to return plain Python numbers (the log prints
# bare floats, not tensor reprs) — confirm against utils.
log_interval = 100
num_batches = len(data_loader)

# Idempotent; keeps this cell correct even if the eval() cell was skipped.
model.eval()

start_time = time.time()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
batch_n = 0  # stays 0 if the loader yields nothing

# The forward pass is wrapped too. Previously only the metric code ran under
# no_grad, so avoiding an autograd graph depended on the separate
# `requires_grad = False` cell having been run.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        if batch_n % log_interval == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, num_batches), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

# CUDA work is queued asynchronously; wait for it so the measured time is accurate.
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.time()

# Divide by the number of batches actually processed. max(..., 1) avoids a
# ZeroDivisionError on an empty loader.
n_evaluated = max(batch_n, 1)
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_evaluated,
              total_pre / n_evaluated,
              total_rec / n_evaluated,
              total_f_score / n_evaluated))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9208464026451111 pre:0.8328893780708313 recall:0.8161066174507141 F-measure:0.8242954015731812
Batch:200/1255 acc:0.9231030941009521 pre:0.8303602933883667 recall:0.8075926303863525 F-measure:0.8193404674530029
Batch:300/1255 acc:0.9215087294578552 pre:0.8294568657875061 recall:0.8034611940383911 F-measure:0.8169447779655457
Batch:400/1255 acc:0.9189901351928711 pre:0.8314839005470276 recall:0.8038398027420044 F-measure:0.8185088634490967
Batch:500/1255 acc:0.9171459078788757 pre:0.8382471799850464 recall:0.8000597953796387 F-measure:0.822727620601654
Batch:600/1255 acc:0.9164101481437683 pre:0.8382056951522827 recall:0.8022099733352661 F-measure:0.8231571316719055
Batch:700/1255 acc:0.9185934662818909 pre:0.8380021452903748 recall:0.8019123077392578 F-measure:0.8230587244033813
Batch:800/1255 acc:0.9156283140182495 pre:0.840647280216217 recall:0.797565221786499 F-measure:0.8237743973731995
Batch:900/1255 acc:0.915982186794281 pre:0.8250539302825928 recall:0.786543