In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Record the library versions this notebook was run with (one per line).
print(*(module.__version__ for module in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# DUTS-TE test split: input images and their ground-truth masks.
path_image = "./DUTS/DUTS-TE/DUTS-TE-Image/"
path_mask = "./DUTS/DUTS-TE/DUTS-TE-Mask/"

# Evaluation settings.
target_size = 256  # passed to LoadTest; presumably the resize target — confirm in data.py
batch_size = 4  # limited by (budget) hardware: 4 is the largest batch that could be used

In [4]:
# Test-set loader; shuffle=False keeps batches in a fixed order so runs are comparable.
test_set = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of mini-batches the loader yields per pass over the test set.
total_batch = len(data_loader)
# The evaluation cell divides by the batch count; fail here with a clear message
# (e.g. wrong dataset paths) rather than with a ZeroDivisionError much later.
assert total_batch > 0, "data_loader is empty - check path_image / path_mask"
total_batch

1255

In [6]:
model = VGG16()
# map_location="cpu": the checkpoint may hold GPU tensors; loading them onto the CPU
# works on any machine, and the model is moved to `device` afterwards.
# strict=True: a missing or unexpected key would otherwise be ignored silently,
# leaving layers at their initial weights and making every metric below meaningless.
model.load_state_dict(torch.load("./model/model_55_edge.pth", map_location="cpu"),
                      strict=True)

<All keys matched successfully>

In [7]:
# Switch to evaluation mode and move the weights to the selected device.
# Ending the cell with an assignment keeps Jupyter from echoing the full module repr.
model = model.to(device).eval()




In [8]:
# Evaluation only: freeze every parameter so none of them tracks gradients.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model over the whole test loader, accumulating the per-batch metrics
# returned by utils (accuracy / precision / recall / F_Measure) and reporting their means.
start_time = time.time()

total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # batches actually processed (equals total_batch after a full pass)

# Inference only: run the forward pass *and* the metrics without building an autograd graph.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # F-measure is computed per batch from that batch's precision/recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        # Running averages every 100 batches (plain arithmetic, no grad context needed).
        if batch_n % 100 == 0:
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(total_acc / batch_n,
                          total_pre / batch_n,
                          total_rec / batch_n,
                          total_f_score / batch_n))
end_time = time.time()

if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches - nothing was evaluated")

# NOTE(review): these are means of per-batch metrics; the last batch may be smaller
# than batch_size yet is weighted like a full batch.
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9209449291229248 pre:0.8354966640472412 recall:0.8186913132667542 F-measure:0.8269980549812317
Batch:200/1255 acc:0.9231404066085815 pre:0.8332966566085815 recall:0.8061332702636719 F-measure:0.8208922147750854
Batch:300/1255 acc:0.9220535755157471 pre:0.834407389163971 recall:0.801975429058075 F-measure:0.820528507232666
Batch:400/1255 acc:0.9194115400314331 pre:0.8355841040611267 recall:0.8025632500648499 F-measure:0.8214250802993774
Batch:500/1255 acc:0.91744464635849 pre:0.8409979939460754 recall:0.7977848649024963 F-measure:0.8243734836578369
Batch:600/1255 acc:0.9166619181632996 pre:0.8400440216064453 recall:0.8011245131492615 F-measure:0.8246068954467773
Batch:700/1255 acc:0.9188134074211121 pre:0.8398867249488831 recall:0.8006138205528259 F-measure:0.8245748281478882
Batch:800/1255 acc:0.9159217476844788 pre:0.84172123670578 recall:0.7971342206001282 F-measure:0.8250235319137573
Batch:900/1255 acc:0.9166410565376282 pre:0.8278377056121826 recall:0.787015914