In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Report the version of each core library, one per line, so the environment
# that produced the results below is recorded in the notebook output.
print(*(lib.__version__ for lib in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: displays the selected device as the cell output.
device

device(type='cuda')

In [3]:
# Directories holding the DUTS-TE images and their masks.
path_image, path_mask = (
    "./DUTS/DUTS-TE/DUTS-TE-Image/",
    "./DUTS/DUTS-TE/DUTS-TE-Mask/",
)

# Size argument handed to the dataset class
# (presumably the side length images are resized to -- confirm in data.py).
target_size = 256
# Constrained by a limited hardware budget; 4 is the largest batch that fits.
batch_size = 4

In [4]:
# Build the test dataset first, then wrap it in a DataLoader.
# Shuffling is disabled, so batches are produced in the dataset's own order.
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Number of batches the loader will yield; used later for progress messages
# and to average the accumulated metrics.
total_batch = len(data_loader)
# Bare expression: displays the batch count as the cell output.
total_batch

1255

In [6]:
# Instantiate the network and restore the trained weights.
model = VGG16()
# map_location="cpu" lets a checkpoint that was saved from a CUDA device be
# deserialized on a machine without a GPU; the model is moved to the selected
# device in a later cell.
state_dict = torch.load("./model/model_53_edge.pth", map_location="cpu")
# strict=False tolerates missing/unexpected keys instead of raising, so the
# result is left as the cell's last expression: check the displayed report to
# confirm that all keys actually matched.
model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

In [7]:
# Put the model in evaluation mode and move it to the selected device
# (assumes VGG16 follows the usual torch.nn.Module behaviour -- confirm in model.py).
model.eval()
model.to(device)
# Presumably here so that the cell does not end on an expression whose value
# the notebook would echo; it only emits an empty line.
print()




In [8]:
# Turn off gradient tracking on every parameter tensor of the model;
# this notebook only evaluates, it never trains.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model on every batch of the test loader, accumulating the
# per-batch accuracy, precision, recall and F-measure, and report running
# averages every 100 batches plus a final summary with the elapsed time.
start_time = time.time()

total_loss = 0  # never updated in this cell; kept so the name stays defined
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0

# The forward pass and the metric helpers all run under no_grad, so this cell
# never builds an autograd graph even if the parameter-freezing cell was skipped.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        # Metric helpers come from utils; F_Measure combines the batch's
        # precision and recall.
        acc = accuracy(predict, mask)
        pre = precision(predict, mask)
        rec = recall(predict, mask)
        f_score = F_Measure(pre, rec)

        total_acc += acc
        total_pre += pre
        total_rec += rec
        total_f_score += f_score

        # Periodic progress report with the running averages so far.
        if batch_n % 100 == 0:
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format( batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))

# CUDA kernels are launched asynchronously; wait for queued work to finish so
# the measured wall-clock time covers the whole evaluation.
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.time()

# NOTE(review): every batch is weighted equally in the final averages. If the
# last batch is smaller than batch_size and the helpers return per-batch means,
# this differs slightly from a per-image average -- confirm against utils.
print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time, 
              total_acc / total_batch,
              total_pre / total_batch,
              total_rec / total_batch,
              total_f_score / total_batch))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9215604662895203 pre:0.8447538614273071 recall:0.8071174621582031 F-measure:0.8311780095100403
Batch:200/1255 acc:0.9240925908088684 pre:0.8449927568435669 recall:0.7990853786468506 F-measure:0.8283079862594604
Batch:300/1255 acc:0.922678530216217 pre:0.8432189226150513 recall:0.7979005575180054 F-measure:0.825840175151825
Batch:400/1255 acc:0.9201900959014893 pre:0.8449655771255493 recall:0.798370897769928 F-measure:0.8274264931678772
Batch:500/1255 acc:0.9181982278823853 pre:0.8502163290977478 recall:0.7945160865783691 F-measure:0.8304034471511841
Batch:600/1255 acc:0.9171571135520935 pre:0.8469772934913635 recall:0.7971896529197693 F-measure:0.8286811113357544
Batch:700/1255 acc:0.9192566871643066 pre:0.8457747101783752 recall:0.7970497012138367 F-measure:0.8279145956039429
Batch:800/1255 acc:0.9162204265594482 pre:0.8478636741638184 recall:0.7927688360214233 F-measure:0.8281283974647522
Batch:900/1255 acc:0.9169363975524902 pre:0.8351535797119141 recall:0.78256