In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions this evaluation ran against, one per line.
print("\n".join(lib.__version__ for lib in (torch, torchvision, cv2, np)))

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Directories handed to LoadTest when the evaluation loader is built
# (presumably the DUTS-TE images and their ground-truth masks — confirm against data.py).
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

batch_size = 4 # constrained by a limited hardware budget; 4 is the maximum that fits
# Passed to LoadTest as its third argument; presumably the side length inputs are resized to — TODO confirm.
target_size = 256

In [4]:
# Build the evaluation dataset first, then wrap it in a loader that
# iterates the samples in dataset order (no shuffling).
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of batches the loader yields per pass; used later as the progress
# denominator and to average the accumulated metrics.
total_batch = len(data_loader)
total_batch

1255

In [6]:
model = VGG16()
# Deserialize onto the CPU so the checkpoint also loads on machines without
# CUDA; the model itself is moved to `device` in a later cell.
state_dict = torch.load("./model/model_57_edge.pth", map_location="cpu")
# strict=False tolerates key mismatches; the returned report (echoed as the
# cell output) lists any missing or unexpected keys, so check it.
model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

In [7]:
# Switch the network to inference mode and move it onto the selected device.
model.eval().to(device)
print()  # ends the cell with a blank line instead of echoing the module repr




In [8]:
# Evaluation only: turn off gradient tracking on every parameter tensor.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model over the whole test loader, accumulating per-batch
# accuracy / precision / recall / F-measure and reporting running averages.
start_time = time.time()

total_loss = 0  # not updated in this cell; kept so the notebook namespace is unchanged
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# One no_grad scope covers the forward pass as well as the metric helpers, so
# no autograd state is recorded regardless of the parameters' requires_grad flags.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):

        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure is derived from this batch's precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        # Progress report every 100 batches: averages over the batches seen so far.
        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format( batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()
# Final summary: unweighted mean of the per-batch metrics over all batches.
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time, 
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9212790727615356 pre:0.835616409778595 recall:0.8200217485427856 F-measure:0.827032744884491
Batch:200/1255 acc:0.923994243144989 pre:0.8382697701454163 recall:0.8068317174911499 F-measure:0.8248860239982605
Batch:300/1255 acc:0.9223272800445557 pre:0.8382089734077454 recall:0.8002394437789917 F-measure:0.8222478032112122
Batch:400/1255 acc:0.919816792011261 pre:0.8406327366828918 recall:0.8015360236167908 F-measure:0.8243542313575745
Batch:500/1255 acc:0.9178910255432129 pre:0.8451947569847107 recall:0.799152672290802 F-measure:0.8273249864578247
Batch:600/1255 acc:0.9170483350753784 pre:0.8430220484733582 recall:0.8030405640602112 F-measure:0.8268572092056274
Batch:700/1255 acc:0.9191675782203674 pre:0.8416422009468079 recall:0.8036831021308899 F-measure:0.8263193964958191
Batch:800/1255 acc:0.9160598516464233 pre:0.8435882925987244 recall:0.7986080646514893 F-measure:0.826422393321991
Batch:900/1255 acc:0.9166615605354309 pre:0.8300426602363586 recall:0.78787404