In [None]:
import os
import time
from typing import List

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models.alexnet import AlexNet

from dataset_models import DetectionDataset

In [None]:
# Dataset directories. Relative paths are resolved against the notebook's
# current working directory, so run the notebook from the project root.
image_path = os.path.abspath('datasets/JPEGImages/')
anno_path = os.path.abspath('datasets/Annotations/')
prop_dir = os.path.abspath('datasets/Proposals/')

# Fail fast (and report every problem at once) if any required directory is
# missing or is not a directory, rather than erroring later inside the DataLoader.
missing_dirs = [p for p in (image_path, anno_path, prop_dir) if not os.path.isdir(p)]
if missing_dirs:
    raise FileNotFoundError(
        "Dataset directories do not exist or are not directories: " + ", ".join(missing_dirs)
    )


In [None]:
# Preprocessing applied to every proposal crop:
#   1. resize to a fixed 224x224 spatial size,
#   2. convert to a float tensor (ToTensor scales 8-bit pixel values into [0, 1]),
#   3. normalize each channel with (x - 0.5) / 0.5, mapping [0, 1] onto [-1, 1].
# NOTE(review): presumably this must match the preprocessing used when
# ./weights/Detection.pth was trained — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


In [None]:
# Number of proposals per inference batch.
batch_size = 64

# os.cpu_count() returns None when the CPU count cannot be determined;
# in that case fall back to a single DataLoader worker.
cpu_count = os.cpu_count()
num_workers = cpu_count or 1

In [None]:
# Proposal dataset and an ordered loader over it.
# shuffle=False is required: the inference cell turns a sample's position within
# batch i back into a dataset index (position + i * batch size), which is only
# valid when batches are produced in dataset order.
# NOTE(review): anno_path is validated above but not passed to the dataset here.
dataset = DetectionDataset(image_path, prop_dir, transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


In [None]:
# Load the trained two-class AlexNet weights and switch the model to eval mode.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model_path = './weights/Detection.pth'
model = AlexNet(num_classes=2)

# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
# map_location only decides where the loaded tensors are placed; load_state_dict
# copies them into the model's existing parameters, which stay wherever
# AlexNet(...) created them (the CPU by default). The model is not moved to
# `device` in this cell.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()


In [None]:
# Run the classifier over every proposal and record which samples it gets right.
# match_list[i] holds the dataset indices (ascending) of the samples in batch i
# whose predicted class equals the label.
match_list: List[List[int]] = []

# Run inference on the selected device; the loading cell leaves the model on the
# device AlexNet(...) was created on, so move it explicitly (no-op if already there).
model.to(device)

with torch.no_grad():
    for i, (val_images, val_labels) in enumerate(loader):
        outputs = model(val_images.to(device))
        # Compare on the CPU, where the DataLoader produces the labels.
        predict_y = outputs.argmax(dim=1).cpu()
        print(f'{i}/{len(loader)}')
        # Use the loader's own batch size for the offset so the indices stay correct
        # even if it differs from the configured `batch_size`. This relies on the
        # loader iterating without shuffling; only the last batch may be shorter.
        offset = i * loader.batch_size
        hits = torch.eq(predict_y, val_labels).nonzero(as_tuple=True)[0]
        match_list.append((hits + offset).tolist())
        
