In [4]:
import time
import os
import torch
import pytorch_mask_rcnn as pmr
    
    
# ------------------ adjustable parameters ---------------------

use_cuda = True # choose to use GPU or not
visualize = False # choose to visualize evaluation results or not
val_num_samples = 100 # number of samples evaluated, between 1 and len(valset) (1444)
ckpt_path = '../checkpoint.pth' # path of the trained checkpoint to load (skipped if the file does not exist)
# dataset directory; set the VOC_DATA_DIR environment variable to override without editing the notebook
data_dir = os.environ.get('VOC_DATA_DIR', 'E:/PyTorch/data/VOC2012')
seed = 0 # seed for picking the evaluation subset; set to None for a different random subset each run

# ------------------ adjustable parameters ---------------------

classes = pmr.dataset.VOC_BBOX_LABEL_NAMES

device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
print('cuda: {}\nuse_cuda: {}\n{} GPU(s) available'.format(torch.cuda.is_available(), use_cuda, torch.cuda.device_count()))
print('\ndevice: {}'.format(device))

# Build the validation set, then keep a random subset of `val_num_samples` images.
valset = pmr.VOCDataset(data_dir, 'val', False, device=device) # len=1444
# A dedicated generator makes the subset reproducible without reseeding the global torch RNG;
# generator=None falls back to the global RNG (non-reproducible subset).
generator = torch.Generator().manual_seed(seed) if seed is not None else None
indices = torch.randperm(len(valset), generator=generator).tolist()
valset = torch.utils.data.Subset(valset, indices[:val_num_samples])

# NOTE(review): args presumably mean (pretrained, num_classes=21 -> 20 VOC classes + background); confirm against pmr.
model = pmr.maskrcnn_resnet50(True, 21).to(device)

if os.path.exists(ckpt_path):
    # map_location loads tensors directly onto the target device, so a checkpoint saved
    # on GPU can still be loaded on a CPU-only machine or when use_cuda is False.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # The weights now live in `model`; drop the loaded copy and release cached GPU memory.
    del checkpoint
    torch.cuda.empty_cache()
else:
    # Without a checkpoint the evaluation below runs on weights that were not loaded from ckpt_path.
    print('\ncheckpoint not found at {}; evaluating without loading a checkpoint'.format(ckpt_path))

cuda: True
use_cuda: True
1 GPU(s) available

device: cuda


In [5]:
# perf_counter is a monotonic high-resolution clock intended for measuring elapsed time.
since = time.perf_counter()

# ------------------ test ---------------------
# remove below '#' to activate ap_getter 
# NOTE(review): collect_data needs a `target`, but this loop only takes `image` from valset,
# so `target` is undefined here; the loop would need to unpack (image, target) first —
# TODO confirm what VOCDataset yields in 'val' mode before enabling.

#ap_getter = pmr.APGetter(20, device)
model.eval()
# Disable autograd once for the whole loop instead of re-entering the context per image.
with torch.no_grad():
    for image in valset:
        result = model(image)

        #ap_getter.collect_data(result, target)

        if visualize:
            print('  '.join(classes[label] for label in result['labels']))
            print('  '.join('{:.2f}'.format(score) for score in result['scores']))
            pmr.show(image, result)
        
#ap_getter.compute_ap()
#print('AP on iou 0.50:0.95')
#print('  '.join(str(round(i * 100, 1)) for i in ap_getter.AP_series))
#print('mAP: {:.1f}'.format(ap_getter.mAP * 100))

# ------------------ test ---------------------

# CUDA kernels run asynchronously; wait for queued GPU work so the timing covers all of it.
if device.type == 'cuda':
    torch.cuda.synchronize()

print('\ntotal time: {:.2f} s'.format(time.perf_counter() - since))


total time: 350.78 s
