In [None]:
import numpy as np
from torch.utils.data import DataLoader
import yaml
import torch
import cv2
from superpoint.data.Synthetic_dataset import SyntheticShapes
from superpoint.models.SuperPoint import SuperPoint
from superpoint.settings import CKPT_PATH
from utils import plot_imgs
import os.path as osp

In [None]:
# Build paths with osp.join instead of hard-coded "\\" separators so the notebook
# is not tied to Windows (and avoids the invalid "\{" escape in an f-string).
config_path = osp.realpath(osp.join(osp.abspath('..'), 'configs', 'magicpoint_syn.yaml'))

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Checkpoint location, relative to CKPT_PATH.
config["pretrained"] = osp.join("magicpoint_syn_v1", "magicpoint_syn_v1.pth")
# Override the detector-head threshold from the YAML file.
config["model"]["detector_head"]["det_thresh"] = 0.015
# NOTE(review): presumably enables homographic augmentation for the test split — confirm in SyntheticShapes.
config["data"]["augmentation"]["homographic"]["enable_test"] = True

dataset = SyntheticShapes(config["data"], task="test", device="cpu")
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=dataset.batch_collator)

model = SuperPoint(config["model"])

# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(osp.join(CKPT_PATH, config["pretrained"]), map_location="cpu")
pretrained_dict = checkpoint["model_state_dict"]

# Copy only the checkpoint entries whose names exist in this model; any model
# parameter missing from the checkpoint keeps its freshly initialised value.
model_state_dict = model.state_dict()
model_state_dict.update({k: v for k, v in pretrained_dict.items() if k in model_state_dict})

model.load_state_dict(model_state_dict)
model.to("cpu").eval();  # trailing ';' suppresses the full model repr in the cell output


In [None]:
def draw_keypoints(img, corners, color):
    """Draw point detections onto an image with OpenCV.

    Args:
        img: image array; it is cast to uint8 before drawing.
        corners: sequence of index arrays (e.g. the output of ``np.where``);
            the first array is used as the row and the second as the column.
        color: colour passed straight through to ``cv2.drawKeypoints``.

    Returns:
        The image returned by ``cv2.drawKeypoints``.
    """
    # One row per detection after the transpose.
    detections = np.stack(corners).T
    keypoints = []
    for point in detections:
        row, col = point[0], point[1]
        # cv2.KeyPoint takes (x, y, size): x is the column, y is the row.
        keypoints.append(cv2.KeyPoint(int(col), int(row), 0.5))
    canvas = img.astype(np.uint8)
    return cv2.drawKeypoints(canvas, keypoints, None, color=color)
def draw_overlay(img, mask, color=(0, 0, 255), op=0.5):
    """Alpha-blend a solid colour into ``img`` wherever ``mask`` is non-zero, in place.

    Args:
        img: image array, modified in place. Each selected pixel becomes
            ``pixel * (1 - op) + color * op``, cast back to ``img``'s dtype.
        mask: array whose non-zero entries select the pixels to tint.
        color: colour blended into the selected pixels. A tuple default is used
            to avoid a shared mutable default argument.
        op: blend weight of ``color``; 0 leaves ``img`` unchanged, 1 replaces
            the selected pixels with ``color``.

    Returns:
        None; the result is written into ``img``.
    """
    # Compute the selection once and reuse it for both the read and the write.
    idx = np.where(mask)
    img[idx] = img[idx] * (1 - op) + np.asarray(color) * op
def display(d, pred):
    """Render one batch sample with its predicted keypoints and invalid region.

    NOTE: this name shadows IPython's built-in ``display`` in the notebook namespace.

    Args:
        d: batch dict from the data loader; reads ``d["raw"]["image"]`` and
            ``d["raw"]["valid_mask"]``.
        pred: prediction tensor; its non-zero entries are drawn as keypoints.

    Returns:
        The annotated image produced by ``draw_keypoints`` and then tinted in
        place by ``draw_overlay`` where ``valid_mask`` is zero.
    """
    raw = d["raw"]
    # Scale the first sample of the batch up to the 0-255 range before drawing.
    scaled = raw['image'][0, ...].cpu().squeeze(0).numpy() * 255
    pred_map = pred.cpu().squeeze(0).numpy()
    annotated = draw_keypoints(scaled, np.where(pred_map), (0, 255, 0))
    invalid = np.logical_not(raw['valid_mask'].cpu().squeeze(0).numpy())
    draw_overlay(annotated, invalid, [255, 0, 0])
    return annotated

In [None]:
# Number of test samples to run through the model and plot.
NUM_PREVIEW = 4

# Re-assert eval mode so this cell does not depend on the previous one's last line.
model.eval()
with torch.no_grad():
    # zip() stops as soon as range() is exhausted, so exactly NUM_PREVIEW batches
    # are pulled from the loader.
    for _, data in zip(range(NUM_PREVIEW), data_loader):
        image = data["raw"]["image"]
        outputs = model(image)
        pred = outputs["detector_output"]["pred_pts"]
        plot_imgs([display(data, pred) / 255.], dpi=100)