In [None]:
import os
import time

import torch
from torch.utils.data import DataLoader
from torchvision.models.alexnet import AlexNet
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from dataset_models import BndboxDataset
from selective_search import selective_search


In [None]:
# Path of the input image; later cells pass it to selective_search and BndboxDataset.
image_path = os.path.abspath('datasets/JPEGImages/IMG_000001.jpg')
# isfile (not exists) so a directory at this path is rejected up front instead of
# failing later inside selective search. FileNotFoundError is still an Exception,
# so any caller catching Exception keeps working.
if not os.path.isfile(image_path):
    raise FileNotFoundError(f"{image_path} does not exist or is not a file.")


In [None]:
# Load the AlexNet classifier weights and switch the network to inference mode.
model_path = './weights/Alexnet.pth'
# Must match the classifier-layer shape stored in the checkpoint:
# load_state_dict is strict by default and fails on a size mismatch.
NUM_CLASSES = 94
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"{model_path} does not exist or is not a file.")
model = AlexNet(num_classes=NUM_CLASSES)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch versions accept weights_only=True to restrict this).
# map_location only controls where the loaded tensors land; load_state_dict copies
# them into the model's own parameters, so the model itself is not moved here.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()


In [None]:
# Run selective search on the input image; `rects` holds the candidate regions
# (presumably bounding boxes — the exact format is defined by the selective_search module, TODO confirm).
rects = selective_search(image_path)
# Display how many candidate regions were produced.
len(rects)

In [None]:
# Preprocessing pipeline: resize to 224x224, convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
# NOTE(review): presumably matches the preprocessing used when the weights were trained — confirm.
data_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Dataset built from the image and its selective-search rects, with the transform
# above applied (presumably one item per region — TODO confirm in dataset_models).
validate_dataset = BndboxDataset(image_path, rects, data_transform)


In [None]:
# Batch the dataset for inference. shuffle=False keeps dataset order
# (presumably the same order as `rects`), so outputs can be matched back to regions.
BATCH_SIZE = 4
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an int; 0 means "load in the main process".
NUM_WORKERS = os.cpu_count() or 0
validate_loader = DataLoader(
    validate_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
len(validate_dataset)


In [None]:

# Classify every batch of regions and report the per-batch forward-pass latency.
# The loading cell used `device` only as map_location; the weights were copied into
# the model's CPU parameters. Move the network explicitly so a detected GPU is used.
# Module.to is in place, so re-running this cell is harmless.
model.to(device)
use_cuda = device.type == 'cuda'

with torch.no_grad():
    for val_data in validate_loader:
        # Assumes each batch is a bare image tensor, since it is fed straight into
        # the model — TODO confirm against BndboxDataset.__getitem__.
        inputs = val_data.to(device)
        if use_cuda:
            # Make sure the host-to-device copy is finished before starting the clock.
            torch.cuda.synchronize(device)
        begin = time.perf_counter()
        outputs = model(inputs)
        # torch.max over dim=1 returns (values, indices): the highest score per row
        # and the index of that score, i.e. the predicted class id.
        scores, predict_y = torch.max(outputs, dim=1)
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them so the timing is real.
            torch.cuda.synchronize(device)
        end = time.perf_counter()
        print(f'cost time: {end - begin:.4f}s')
        print(f'predicted classes: {predict_y.tolist()}')
        print(f'max scores: {scores.tolist()}')
        