In [2]:
'''
Author       : Aditya Jain
Date Started : 18th August, 2021
About        : This file does inference on test images for DL-based localization
'''
import torch
import torchvision.models as torchmodels
import torchvision
import os
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms
from PIL import Image
import json
import cv2

  from .autonotebook import tqdm as notebook_tqdm


#### Model Loading

In [25]:
# Location of the fine-tuned localization checkpoint that is restored below.
MODEL_PATH  = '/home/mila/a/aditya.jain/logs/v1_localizmodel_2021-08-17-12-06.pt'

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Start from torchvision's Faster R-CNN (ResNet-50 + FPN) pre-trained on COCO.
model       = torchmodels.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Replace the box-predictor head so that it outputs `num_classes` classes.
num_classes = 2  # 1 foreground class + background (original note said "person" -- TODO confirm)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Restore the saved weights; map_location places the tensors on the selected device.
# Kept as the final expression so the key-matching report is shown as the cell output.
checkpoint  = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

cuda


<All keys matched successfully>

In [29]:
# Run the detector over every image in DATA_PATH and save a copy of each image,
# annotated with the confident bounding boxes, under the same file name in SAVE_PATH.
model       = model.to(device)
model.eval()

DATA_PATH  = '/scratch/Localization/dl_test/orig_resized_images/'
SAVE_PATH  = '/scratch/Localization/dl_test/dl_bbox/'
SCORE_THR  = 0.99   # only boxes scoring above this threshold are drawn
# Sorted so the images are processed in the same order on every run.
image_list = sorted(os.listdir(DATA_PATH))

# cv2.imwrite does not create missing directories; it just reports failure.
os.makedirs(SAVE_PATH, exist_ok=True)

transform  = transforms.Compose([
            transforms.ToTensor()])

for img in image_list:
    image_path = os.path.join(DATA_PATH, img)

    # cv2.imread returns None (instead of raising) for anything it cannot decode,
    # e.g. a sub-directory or a non-image file returned by os.listdir.
    image_cv   = cv2.imread(image_path)
    if image_cv is None:
        print(f'Skipping {image_path}: not a readable image')
        continue

    # Close the file handle promptly and force 3 channels so that grayscale or
    # RGBA files do not reach the model with an unexpected channel count.
    with Image.open(image_path) as pil_image:
        image  = transform(pil_image.convert('RGB'))
    image_pred = torch.unsqueeze(image, 0).to(device)

    # Pure inference: do not record autograd state for the forward pass.
    with torch.no_grad():
        output = model(image_pred)

    bboxes     = output[0]['boxes'][output[0]['scores'] > SCORE_THR]

    for box in bboxes:
        # OpenCV drawing functions expect integer pixel coordinates.
        box_numpy = box.detach().cpu().numpy().round().astype(int)
        x_min, y_min, x_max, y_max = box_numpy.tolist()
        cv2.rectangle(image_cv, (x_min, y_min), (x_max, y_max), (0,0,255), 3)

    save_path = os.path.join(SAVE_PATH, img)
    if not cv2.imwrite(save_path, image_cv):
        print(f'Failed to write {save_path}')

In [3]:
# Read the annotation JSON from the working directory into `data`.
with open('set2_maxim.json') as read_file:
    data = json.loads(read_file.read())

In [6]:
# Write `data` back out under a new file name, pretty-printed with 4-space indentation.
with open("set2_maxim-kent.json", "w") as write_file:
    serialized = json.dumps(data, indent=4)
    write_file.write(serialized)