In [1]:
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

from data import LoadTest
from model import VGG16
from utils import *

# Log the versions of the main libraries, one per line, so the numbers reported
# below can be tied to a specific software environment.
print(*(lib.__version__ for lib in (torch, torchvision, cv2, np)), sep="\n")

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Test-set locations: (image directory, ground-truth mask directory) per dataset.
# Switch datasets by changing `dataset_name` instead of (un)commenting path lines.
# NOTE: the DUT-OMRON folder is spelled "DUT-OMROM" exactly as in the original
# paths — verify this matches the directory on disk.
dataset_paths = {
    "DUTS-TE": ("./DUTS/DUTS-TE/DUTS-TE-Image/",
                "./DUTS/DUTS-TE/DUTS-TE-Mask/"),
    "DUT-OMRON": ("./DUT-OMROM/DUT-OMRON-image/",
                  "./DUT-OMROM/pixelwiseGT-new-PNG/"),
}
dataset_name = "DUTS-TE"
path_image, path_mask = dataset_paths[dataset_name]

batch_size = 4  # largest batch size the available GPU memory could handle
target_size = 256  # passed to LoadTest; presumably the resize resolution — confirm in data.py

In [4]:
# Wrap the test images/masks in a loader. shuffle=False keeps a fixed iteration
# order, so the periodic progress logs are comparable between runs.
test_dataset = LoadTest(path_image, path_mask, target_size)
data_loader = data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Number of batches the loader yields in one pass over the test set
# (shown as the total in the evaluation progress log).
total_batch = len(data_loader)
total_batch

1255

In [6]:
# Build the network and load the trained weights.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine too (the device selection above explicitly supports that case); the
# model is moved to `device` afterwards.
# strict=False tolerates missing/unexpected keys, so check the displayed result
# for mismatches. torch.load unpickles the file — only load trusted checkpoints.
checkpoint_path = "./model/model_62_edge.pth"
model = VGG16()
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)

<All keys matched successfully>

In [7]:
# Put the model in evaluation mode and move its weights to the chosen device.
# Module.eval() and Module.to() both return the module itself, so rebinding
# `model` keeps the same object; making the cell end in an assignment suppresses
# the module repr without printing a stray blank line.
model = model.eval().to(device)




In [8]:
# Freeze every weight tensor: evaluation only runs forward passes, so no
# parameter needs a gradient.
for param in model.parameters():
    param.requires_grad_(False)

In [9]:

# Evaluate the model over the whole test set.
# Each metric is computed per batch by the `utils` helpers and then averaged over
# batches (not images): if the final batch is smaller than the others, its images
# carry slightly more weight.
log_every = 100  # print running averages every `log_every` batches

start_time = time.time()

total_loss = 0  # not updated in this cell: no loss is computed during evaluation
total_acc = 0
total_pre = 0
total_rec = 0
total_f_score = 0
n_batches = 0  # batches actually processed; denominator of the final averages

# no_grad covers the forward pass as well as the metrics. Evaluation never
# back-propagates, so no autograd graph is recorded even if the parameters were
# not frozen by an earlier cell.
with torch.no_grad():
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        image = image.to(device)
        mask = mask.to(device)

        predict = model(image)

        acc = accuracy(predict, mask)
        total_acc += acc

        pre = precision(predict, mask)
        total_pre += pre

        rec = recall(predict, mask)
        total_rec += rec

        # This batch's F-measure, derived from its precision and recall.
        f_score = F_Measure(pre, rec)
        total_f_score += f_score

        n_batches = batch_n

        if batch_n % log_every == 0:
            # Running means over the batches seen so far.
            avg_acc = total_acc / batch_n
            avg_pre = total_pre / batch_n
            avg_rec = total_rec / batch_n
            avg_f_score = total_f_score / batch_n
            print("Batch:{}/{}".format(batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}"
                  .format(avg_acc, avg_pre, avg_rec, avg_f_score))
end_time = time.time()

if n_batches == 0:
    raise RuntimeError("data_loader yielded no batches; check path_image / path_mask")

print("--------------------------------------------------------------")
print("time:{:.2f}s END : acc:{} pre:{} rec:{} F-measure:{}"
      .format(end_time - start_time,
              total_acc / n_batches,
              total_pre / n_batches,
              total_rec / n_batches,
              total_f_score / n_batches))
print("--------------------------------------------------------------")

Batch:100/1255 acc:0.9204647541046143 pre:0.8331215381622314 recall:0.8159812688827515 F-measure:0.8240495920181274
Batch:200/1255 acc:0.9231460094451904 pre:0.8335710763931274 recall:0.8082403540611267 F-measure:0.8220932483673096
Batch:300/1255 acc:0.9217603802680969 pre:0.8323397636413574 recall:0.8048893809318542 F-measure:0.8199082612991333
Batch:400/1255 acc:0.9191410541534424 pre:0.8342349529266357 recall:0.8044100999832153 F-measure:0.8209814429283142
Batch:500/1255 acc:0.9172236919403076 pre:0.8396075367927551 recall:0.8000918626785278 F-measure:0.8240416646003723
Batch:600/1255 acc:0.9164923429489136 pre:0.8386304378509521 recall:0.8024688959121704 F-measure:0.8239547610282898
Batch:700/1255 acc:0.9184520244598389 pre:0.8370122909545898 recall:0.8029640913009644 F-measure:0.8228388428688049
Batch:800/1255 acc:0.9154986143112183 pre:0.840238094329834 recall:0.7978821396827698 F-measure:0.8236758708953857
Batch:900/1255 acc:0.9159259796142578 pre:0.8249093890190125 recall:0.787