In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the versions of the core libraries so results can be tied to the
# environment that produced them (one version string per line).
reported_libs = (torch, torchvision, cv2, np)
for lib in reported_libs:
    lib_version = lib.__version__
    print(lib_version)

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Test-split image directory and matching ground-truth mask directory.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Limited by the available hardware budget: 4 is the largest batch that fits.
batch_size = 4
# Passed to LoadTest; presumably the side length images are resized to -- confirm in data.py.
target_size = 256

In [4]:
# Build the test dataset, then iterate it in fixed order (no shuffling) so
# batch numbering in the evaluation log is reproducible between runs.
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Mini-batches per pass over the test set (DataLoader keeps a final partial
# batch by default); shown in the progress log of the evaluation loop.
total_batch = len(data_loader)
total_batch

1255

In [6]:
model = VGG16()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine too (the device cell falls back to CPU); the model is moved
# to `device` in a later cell.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint_state = torch.load("./model/model_54_edge.pth", map_location="cpu")
# strict=False does not raise on key mismatches. A *missing* key would leave
# that layer at its initial weights and silently corrupt the evaluation, so
# fail loudly on it; unexpected (extra) checkpoint keys are still tolerated.
load_result = model.load_state_dict(checkpoint_state, strict=False)
assert not load_result.missing_keys, (
    "checkpoint is missing weights for: {}".format(load_result.missing_keys))
load_result

<All keys matched successfully>

In [7]:
# Switch to inference behaviour (e.g. dropout / batch-norm in eval mode) and
# move the weights to the selected device. nn.Module.eval() and .to() return
# the module itself, so rebinding `model` is a no-op on identity; having an
# assignment as the last statement keeps Jupyter from echoing the whole
# architecture repr (previously hidden with a stray `print()`).
model = model.eval().to(device)




In [8]:
# Freeze every weight tensor: this notebook only runs forward passes, so no
# parameter needs gradient tracking.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model on the test loader. Running averages of the per-batch
# metrics are printed every 100 batches, and overall averages at the end.
# accuracy / precision / recall / F_Measure come from `utils` (star import);
# their exact definitions are not visible here.
# NOTE(review): averages are taken over batches, not images -- if the metrics
# return a per-batch mean, a smaller final batch carries full weight.
start_time = time.perf_counter()  # monotonic clock, suited to durations

total_loss = 0  # NOTE(review): never updated -- no loss is computed in this cell
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
num_batches = 0  # batches actually processed; denominator for the final averages

# The forward pass must run under no_grad (not just the metric code) so that no
# autograd graph is built, even if the parameter-freezing cell was skipped.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure is computed from this batch's precision/recall, then averaged.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        num_batches = batch_n

        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

end_time = time.perf_counter()

if num_batches == 0:
    raise RuntimeError("data_loader yielded no batches; nothing to evaluate")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / num_batches,
              total_pre / num_batches,
              total_rec / num_batches,
              total_f_score / num_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9209011793136597 pre:0.834033727645874 recall:0.816321074962616 F-measure:0.8251097798347473
Batch:200/1255 acc:0.923833966255188 pre:0.8374187350273132 recall:0.8066933155059814 F-measure:0.824855387210846
Batch:300/1255 acc:0.9222798943519592 pre:0.8354811072349548 recall:0.8035905361175537 F-measure:0.8220124244689941
Batch:400/1255 acc:0.9195767045021057 pre:0.836449921131134 recall:0.8042813539505005 F-measure:0.8228251338005066
Batch:500/1255 acc:0.9175930023193359 pre:0.8414940237998962 recall:0.8002154231071472 F-measure:0.825809895992279
Batch:600/1255 acc:0.9167186617851257 pre:0.840255081653595 recall:0.8027148246765137 F-measure:0.8255079984664917
Batch:700/1255 acc:0.918907642364502 pre:0.8401086330413818 recall:0.8026100993156433 F-measure:0.825492799282074
Batch:800/1255 acc:0.9158862233161926 pre:0.8422440886497498 recall:0.7975364327430725 F-measure:0.8257001042366028
Batch:900/1255 acc:0.9164173007011414 pre:0.8273929953575134 recall:0.78603488206