In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Log the version of each core library so the results below can be tied to
# a concrete software environment.
for lib in (torch, torchvision, cv2, np):
    print(lib.__version__)

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Directories handed to LoadTest as the image and mask locations.
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

batch_size = 4  # limited by the available hardware budget; 4 is the most that fits
# Passed to LoadTest as its size argument (presumably the resize edge in pixels — confirm in data.py).
target_size = 256

In [4]:
# Build the test dataset first, then wrap it in a loader that yields
# fixed-order mini-batches (no shuffling for evaluation).
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of mini-batches the loader yields per pass; used below as the
# denominator for the final averaged metrics and in the progress log.
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and restore the trained weights from disk.
model = VGG16()
# map_location="cpu" makes the checkpoint loadable on machines without CUDA
# (a checkpoint saved from a GPU otherwise fails to deserialize there); the
# model is moved to the selected device in a later cell.
state_dict = torch.load("./model/model_51_bce.pth", map_location="cpu")
# strict=False tolerates missing/unexpected keys instead of raising. The
# returned key report is the cell output: verify that it lists no missing or
# unexpected keys before trusting the metrics below.
model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

In [7]:
# Put the model in evaluation mode and move it to the selected device.
# Both calls return the module itself, so binding the result back to `model`
# keeps the notebook from echoing the module repr without needing a dummy
# print() that emits an empty output.
model = model.eval().to(device)




In [8]:
# Freeze every parameter tensor: gradients are not needed for evaluation.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model over the whole test loader, logging running averages of
# the per-batch metrics every 100 batches and the final averages at the end.
start_time = time.time()

# Running sums of the per-batch metrics; divided by a batch count when reported.
total_loss = 0  # never updated in this cell; kept so the name stays defined
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# The forward pass itself runs under no_grad so that no autograd graph is
# recorded even if the parameter-freezing cell was skipped or run out of order.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        # Per-batch metrics (helpers come from utils).
        # NOTE(review): the logged values print as plain floats, so these
        # helpers presumably return Python scalars — confirm in utils.
        acc = accuracy(predict, mask)
        pre = precision(predict, mask)
        rec = recall(predict, mask)
        f_score = F_Measure(pre, rec)

        total_acc += acc
        total_pre += pre
        total_rec += rec
        total_f_score += f_score

        # Periodic progress line: averages over the batches seen so far.
        if batch_n % 100 == 0:
            print("Batch:{}/{} acc:{} pre:{} recall:{} F-measure:{}"
                  .format(batch_n, total_batch,
                          total_acc / batch_n,
                          total_pre / batch_n,
                          total_rec / batch_n,
                          total_f_score / batch_n))

end_time = time.time()
print("--------------------------------------------------------------")
# Final figures are means of the per-batch metrics over all batches.
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9184567332267761 pre:0.8005125522613525 recall:0.8605906963348389 F-measure:0.8089005351066589
Batch:200/1255 acc:0.9209272265434265 pre:0.7991586923599243 recall:0.848027229309082 F-measure:0.8042546510696411
Batch:300/1255 acc:0.9196471571922302 pre:0.7976242899894714 recall:0.8453198671340942 F-measure:0.8020414710044861
Batch:400/1255 acc:0.9174280166625977 pre:0.8010294437408447 recall:0.8472295999526978 F-measure:0.8053238987922668
Batch:500/1255 acc:0.915698230266571 pre:0.8076207637786865 recall:0.8433263301849365 F-measure:0.8097517490386963
Batch:600/1255 acc:0.9149978756904602 pre:0.80825275182724 recall:0.8451533913612366 F-measure:0.8108066916465759
Batch:700/1255 acc:0.9169892072677612 pre:0.8061968684196472 recall:0.8452948927879333 F-measure:0.8093804717063904
Batch:800/1255 acc:0.9141631126403809 pre:0.8098688125610352 recall:0.8405569791793823 F-measure:0.8109840154647827
Batch:900/1255 acc:0.9143108129501343 pre:0.7943586111068726 recall:0.832879