diff --git a/match_two.py b/match_two.py
index 076947b..700d0c4 100755
--- a/match_two.py
+++ b/match_two.py
@@ -95,7 +95,10 @@ def match_two(model, device, config, im_one, im_two, plot_save_path):
     model.eval()
 
-    it = input_transform((int(config['feature_extract']['imageresizeH']), int(config['feature_extract']['imageresizeW'])))
+    crop_roi = None
+    if 'imagecrop' in config['feature_extract']:
+        crop_roi = tuple([int(x) for x in config['feature_extract']['imagecrop'].split(",")])
+    it = input_transform((int(config['feature_extract']['imageresizeH']), int(config['feature_extract']['imageresizeW'])), crop_roi)
 
     im_one_pil = Image.fromarray(cv2.cvtColor(im_one, cv2.COLOR_BGR2RGB))
     im_two_pil = Image.fromarray(cv2.cvtColor(im_two, cv2.COLOR_BGR2RGB))
 
@@ -148,6 +151,10 @@ def match_two(model, device, config, im_one, im_two, plot_save_path):
         tqdm.write('====> Plotting Local Features and save them to ' + str(join(plot_save_path, 'patchMatchings.png')))
 
         # using cv2 for their in-built keypoint correspondence plotting tools
+        if crop_roi is not None:
+            top, left, bottom, right = crop_roi
+            im_one = im_one[top:bottom,left:right]
+            im_two = im_two[top:bottom,left:right]
         cv_im_one = cv2.resize(im_one, (int(config['feature_extract']['imageresizeW']), int(config['feature_extract']['imageresizeH'])))
         cv_im_two = cv2.resize(im_two, (int(config['feature_extract']['imageresizeW']), int(config['feature_extract']['imageresizeH'])))
         # cv2 resize slightly different from torch, but for visualisation only not a big problem
diff --git a/patchnetvlad/tools/datasets.py b/patchnetvlad/tools/datasets.py
index 845f16e..59a7005 100644
--- a/patchnetvlad/tools/datasets.py
+++ b/patchnetvlad/tools/datasets.py
@@ -27,6 +27,7 @@
 import os
 
 import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
 import torch.utils.data as data
 
 import numpy as np
@@ -36,20 +37,19 @@
 from patchnetvlad.tools import PATCHNETVLAD_ROOT_DIR
 
 
-def input_transform(resize=(480, 640)):
+def input_transform(resize=(480, 640), crop_roi=None):
+    trans = []
+    if crop_roi is not None and len(crop_roi) == 4 and all([x >= 0 for x in crop_roi]):
+        top, left, bottom, right = crop_roi
+        trans.append(transforms.Lambda(lambda x: TF.crop(x, top, left, bottom-top, right-left)))
     if resize[0] > 0 and resize[1] > 0:
-        return transforms.Compose([
-            transforms.Resize(resize),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                 std=[0.229, 0.224, 0.225]),
-        ])
-    else:
-        return transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                 std=[0.229, 0.224, 0.225]),
-        ])
+        trans.append(transforms.Resize(resize))
+    trans.extend([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+    return transforms.Compose(trans)
 
 
 class PlaceDataset(data.Dataset):
@@ -80,8 +80,11 @@ def __init__(self, query_file_path, index_file_path, dataset_root_dir, ground_tr
         self.positives = None
         self.distances = None
 
+        crop_roi = None
+        if 'imagecrop' in config:
+            crop_roi = tuple([int(x) for x in config['imagecrop'].split(',')])
         self.resize = (int(config['imageresizeH']), int(config['imageresizeW']))
-        self.mytransform = input_transform(self.resize)
+        self.mytransform = input_transform(self.resize, crop_roi)
 
     def __getitem__(self, index):
         img = Image.open(self.images[index])
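
Minimal usage sketch (not part of the patch): the new crop_roi is interpreted as (top, left, bottom, right) in pixels of the original image and is applied before resizing; in the ini configs the same region would be supplied as a comma-separated value, e.g. imagecrop = 100,50,480,600. The image path and crop values below are placeholder assumptions.

# Hypothetical example exercising the new crop_roi argument of input_transform.
from PIL import Image

from patchnetvlad.tools.datasets import input_transform

crop_roi = (100, 50, 480, 600)                 # top, left, bottom, right (placeholder values)
transform = input_transform(resize=(480, 640), crop_roi=crop_roi)

img = Image.open('example.jpg')                # placeholder path
tensor = transform(img)                        # crop -> resize -> ToTensor -> Normalize
print(tensor.shape)                            # torch.Size([3, 480, 640])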