This example shows how to match image keypoints with the graph matching solvers provided by pygmtools. These solvers follow the Quadratic Assignment Problem (QAP) formulation and generally work out of the box. The matched images can be further processed for other downstream tasks.
Note
The following solvers support the QAP formulation and are included in this example:
- ~pygmtools.classic_solvers.rrwm (classic solver)
- ~pygmtools.classic_solvers.ipfp (classic solver)
- ~pygmtools.classic_solvers.sm (classic solver)
- ~pygmtools.neural_solvers.ngm (neural network solver)
import torch # pytorch backend
import torchvision # CV models
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import scipy.io as sio # for loading .mat file
import scipy.spatial as spa # for Delaunay triangulation
from sklearn.decomposition import PCA as PCAdimReduc
import itertools
import numpy as np
from PIL import Image
pygm.set_backend('pytorch') # set default backend for pygmtools
Images are from the Willow Object Class dataset (this dataset is also available with the Benchmark of pygmtools; see ~pygmtools.dataset.WillowObject).
The images are resized to 256x256.
obj_resize = (256, 256)
img1 = Image.open('../data/willow_duck_0001.png')
img2 = Image.open('../data/willow_duck_0002.png')
kpts1 = torch.tensor(sio.loadmat('../data/willow_duck_0001.mat')['pts_coord'])
kpts2 = torch.tensor(sio.loadmat('../data/willow_duck_0002.mat')['pts_coord'])
kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
img1 = img1.resize(obj_resize, resample=Image.BILINEAR)
img2 = img2.resize(obj_resize, resample=Image.BILINEAR)
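The keypoint rescaling above multiplies x coordinates by the width ratio and y coordinates by the height ratio. Below is a minimal NumPy sketch of the same arithmetic. It assumes keypoints are stored as a 2xN array (row 0: x, row 1: y) and that sizes are given as (width, height), which is the order PIL's Image.size uses. rescale_keypoints is a hypothetical helper, not part of pygmtools.

```python
import numpy as np

def rescale_keypoints(kpts, old_size, new_size):
    """Rescale 2xN keypoints from old_size=(width, height) to new_size=(width, height)."""
    kpts = np.asarray(kpts, dtype=np.float64)
    # Per-axis scale factors: x follows the width ratio, y follows the height ratio.
    scale = np.array([new_size[0] / old_size[0], new_size[1] / old_size[1]])
    return kpts * scale[:, None]

# Example: two keypoints in a 200x80 image mapped into a 256x256 image.
print(rescale_keypoints([[100, 50], [20, 40]], (200, 80), (256, 256)))
```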
Visualize the images and keypoints
def plot_image_with_graph(img, kpt, A=None):
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is not None:
        for idx in torch.nonzero(A, as_tuple=False):
            plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1')
plot_image_with_graph(img1, kpts1)
plt.subplot(1, 2, 2)
plt.title('Image 2')
plot_image_with_graph(img2, kpts2)
[Figure: Image 1 and Image 2 with their keypoints]
Graph structures are built from the geometric structure of the keypoint set. In this example, we use Delaunay triangulation.
We encode the lengths of the edges as edge features.
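Below is a minimal NumPy/SciPy sketch of these two steps. It assumes keypoints are stored as a 2xN array of (x, y) coordinates. delaunay_adjacency and edge_length_features are hypothetical helpers, not the example's own code, and dividing the lengths by the longest edge is an illustrative normalization choice.

```python
import itertools
import numpy as np
import scipy.spatial as spa

def delaunay_adjacency(kpts):
    """0/1 adjacency matrix of the Delaunay triangulation of 2xN keypoints."""
    n = kpts.shape[1]
    tri = spa.Delaunay(kpts.T)               # SciPy expects an N x 2 array of points
    A = np.zeros((n, n))
    for simplex in tri.simplices:            # each simplex holds 3 vertex indices
        for i, j in itertools.permutations(simplex, 2):
            A[i, j] = 1.0                    # connect every ordered pair in the triangle
    return A

def edge_length_features(kpts, A):
    """Euclidean edge lengths on the edges of A, divided by the longest edge."""
    diff = kpts[:, :, None] - kpts[:, None, :]        # 2 x N x N coordinate differences
    lengths = np.sqrt((diff ** 2).sum(axis=0)) * A    # zero where there is no edge
    return lengths / lengths.max()

# Example: a 3-4-5 right triangle gives a single Delaunay triangle.
kpts = np.array([[0.0, 3.0, 0.0], [0.0, 0.0, 4.0]])
A = delaunay_adjacency(kpts)
E = edge_length_features(kpts, A)
```

Every ordered pair of triangle vertices is connected, so the adjacency matrix is symmetric with a zero diagonal.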
Visualize the graphs
[Figure: Image 1 and Image 2 with their graphs]
Let's adopt the VGG16 CNN model to extract node features.
vgg16_cnn = torchvision.models.vgg16_bn(True)
torch_img1 = torch.from_numpy(np.array(img1, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
torch_img2 = torch.from_numpy(np.array(img2, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
with torch.set_grad_enabled(False):
    feat1 = vgg16_cnn.features(torch_img1)
    feat2 = vgg16_cnn.features(torch_img2)
Normalize the features
Up-sample the features to the original image size
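Below is a minimal NumPy sketch of these two steps for a single CxHxW feature map. It assumes per-pixel L2 normalization across channels and nearest-neighbor upsampling. The example's own pipeline may use a different normalization or a smoother interpolation mode such as bilinear. l2_normalize_channels and upsample_nearest are hypothetical names.

```python
import numpy as np

def l2_normalize_channels(feat, eps=1e-12):
    """Divide each pixel's channel vector of a CxHxW map by its L2 norm."""
    norm = np.sqrt((feat ** 2).sum(axis=0, keepdims=True))  # shape 1xHxW
    return feat / np.maximum(norm, eps)                       # eps avoids division by zero

def upsample_nearest(feat, out_h, out_w):
    """Nearest-neighbor upsampling of a CxHxW map to C x out_h x out_w."""
    _, h, w = feat.shape
    rows = np.arange(out_h) * h // out_h   # source row for each output row
    cols = np.arange(out_w) * w // out_w   # source column for each output column
    return feat[:, rows[:, None], cols[None, :]]

# Example: an 8x8 map with 512 channels (the size VGG16 features would have for a
# 256x256 input) brought back to 256x256.
feat = np.random.default_rng(0).normal(size=(512, 8, 8))
up = upsample_nearest(l2_normalize_channels(feat), 256, 256)
```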
Visualize the extracted CNN features (dimensionality reduction via principal component analysis)
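The visualization projects each pixel's C-dimensional feature onto the top three principal directions and rescales each component to [0, 1], so the result can be shown as RGB. Below is a NumPy-only sketch that uses SVD instead of scikit-learn's PCA. pca_rgb is a hypothetical name. Whitening is omitted because the per-component max-abs rescaling that follows cancels any per-component scale. The colors may still differ from scikit-learn's output by a sign flip per component.

```python
import numpy as np

def pca_rgb(features, n_components=3):
    """Project M x C features onto the top principal directions and map them to [0, 1]."""
    centered = features - features.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    proj = centered @ vt[:n_components].T                     # M x n_components
    # Same rescaling as the example: each component to [-0.5, 0.5], then shift by 0.5.
    return proj / np.max(np.abs(proj), axis=0, keepdims=True) / 2 + 0.5

rng = np.random.default_rng(0)
rgb = pca_rgb(rng.normal(size=(100, 8)))   # 100 "pixels" with 8-dim features
```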
pca_dim_reduc = PCAdimReduc(n_components=3, whiten=True)
feat_dim_reduc = pca_dim_reduc.fit_transform(
np.concatenate((
feat1_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy(),
feat2_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy()
), axis=0)
)
feat_dim_reduc = feat_dim_reduc / np.max(np.abs(feat_dim_reduc), axis=0, keepdims=True) / 2 + 0.5
feat1_dim_reduc = feat_dim_reduc[:obj_resize[0] * obj_resize[1], :]
feat2_dim_reduc = feat_dim_reduc[obj_resize[0] * obj_resize[1]:, :]
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with CNN features')
plot_image_with_graph(img1, kpts1, A1)
plt.imshow(feat1_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
plt.subplot(1, 2, 2)
plt.title('Image 2 with CNN features')
plot_image_with_graph(img2, kpts2, A2)
plt.imshow(feat2_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
[Figure: Image 1 with CNN features, Image 2 with CNN features]
Extract node features by nearest interpolation
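Here, nearest interpolation means rounding each keypoint to the closest pixel of the up-sampled feature map and reading the feature vector stored there. Below is a NumPy sketch that assumes a CxHxW feature map and 2xN keypoints, with row 0 = x (column index) and row 1 = y (row index). Clipping to the image border is an added safety assumption. sample_node_features is a hypothetical name.

```python
import numpy as np

def sample_node_features(feat, kpts):
    """Read an N x C node feature matrix from a CxHxW map at rounded 2xN keypoints."""
    _, h, w = feat.shape
    xy = np.rint(kpts).astype(int)           # nearest pixel for each keypoint
    x = np.clip(xy[0], 0, w - 1)             # column index
    y = np.clip(xy[1], 0, h - 1)             # row index
    return feat[:, y, x].T                   # C x N -> N x C

feat = np.arange(60, dtype=float).reshape(3, 4, 5)
node = sample_node_features(feat, np.array([[0.0, 4.4], [1.6, 3.0]]))
```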
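We follow the formulation of the Quadratic Assignment Problem (QAP):
$$\max_{\mathbf{X}} \ \mathrm{vec}(\mathbf{X})^\top \mathbf{K} \, \mathrm{vec}(\mathbf{X}) \quad \text{s.t.} \quad \mathbf{X} \in \{0, 1\}^{n_1 \times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top \mathbf{1} \leq \mathbf{1}$$
where the first step is to build the affinity matrix $\mathbf{K}$.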
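To make the objective concrete, the sketch below evaluates vec(X)^T K vec(X), where vec() stacks the columns of X. It also solves a tiny square instance (n1 = n2 = n, where the constraints reduce to permutation matrices) by exhaustive search. This is only a reference for very small n, because there are n! candidates. The solvers in this example approximate this search instead. qap_objective and brute_force_qap are hypothetical helpers.

```python
import itertools
import numpy as np

def qap_objective(K, X):
    """vec(X)^T K vec(X), with vec() stacking the columns of X (column-major)."""
    x = X.reshape(-1, order='F')
    return float(x @ K @ x)

def brute_force_qap(K, n):
    """Exhaustively solve a tiny n-by-n QAP over all permutation matrices."""
    best_X, best_score = None, -np.inf
    for perm in itertools.permutations(range(n)):
        X = np.zeros((n, n))
        X[np.arange(n), list(perm)] = 1.0    # node i in graph 1 -> node perm[i] in graph 2
        score = qap_objective(K, X)
        if score > best_score:
            best_X, best_score = X, score
    return best_X, best_score

# With only node affinities on the diagonal, the QAP reduces to a linear assignment.
P = np.array([[0.1, 0.9, 0.2], [0.8, 0.1, 0.3], [0.2, 0.3, 0.7]])
K = np.diag(P.reshape(-1, order='F'))
X_best, score = brute_force_qap(K, 3)
```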
Visualization of the affinity matrix. For a graph matching problem with $N$ nodes, the affinity matrix has $N^2 \times N^2$ elements because there are $N^2$ edges (ordered node pairs) in each graph.
Note
The diagonal elements are node affinities; the off-diagonal elements are edge affinities.
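The layout described in the note can be sketched as a dense NumPy construction. The sketch makes several assumptions, and it is not the pygmtools implementation:

- vec() is column-major, so row/column i + a*n1 stands for matching node i of graph 1 to node a of graph 2.
- Node affinity is the inner product of node features.
- Edge affinity is a Gaussian kernel exp(-(e1 - e2)^2 / sigma) on scalar edge features.

build_affinity_matrix is a hypothetical name.

```python
import numpy as np

def build_affinity_matrix(node1, node2, E1, E2, A1, A2, sigma=1.0):
    """Dense Lawler-QAP affinity matrix from node features and scalar edge features.

    node1: n1 x C, node2: n2 x C; E1, E2: edge features (n1 x n1, n2 x n2);
    A1, A2: 0/1 adjacency matrices without self-loops.
    """
    n1, n2 = node1.shape[0], node2.shape[0]
    K = np.zeros((n1 * n2, n1 * n2))
    # Off-diagonal: affinity between edge (i, j) of graph 1 and edge (a, b) of graph 2,
    # stored at the pair of candidate matches (i -> a, j -> b).
    for i, j in zip(*np.nonzero(A1)):
        for a, b in zip(*np.nonzero(A2)):
            K[i + a * n1, j + b * n1] = np.exp(-(E1[i, j] - E2[a, b]) ** 2 / sigma)
    # Diagonal: node affinity for each candidate match i -> a.
    node_aff = node1 @ node2.T                        # n1 x n2
    idx = np.arange(n1 * n2)
    K[idx, idx] = node_aff.reshape(-1, order='F')     # column-major, index i + a * n1
    return K

# Example: two identical triangles with one-hot node features.
A = np.ones((3, 3)) - np.eye(3)
E = np.array([[0.0, 3.0, 4.0], [3.0, 0.0, 5.0], [4.0, 5.0, 0.0]])
K = build_affinity_matrix(np.eye(3), np.eye(3), E, E, A, A)
```

Because the graphs have no self-loops, edge affinities never overwrite the diagonal.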
[Figure: visualization of the affinity matrix K]
See ~pygmtools.classic_solvers.rrwm for the API reference.
X = pygm.rrwm(K, kpts1.shape[1], kpts2.shape[1])
The output of RRWM is a soft matching matrix. The Hungarian algorithm is then applied to obtain a discrete matching matrix.
X = pygm.hungarian(X)
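For intuition, below is a simplified NumPy/SciPy sketch of the reweighted random walk idea behind RRWM (Cho et al., 2010). It is followed by Hungarian discretization with scipy.optimize.linear_sum_assignment. This sketch is not the pygmtools implementation. The parameter values (alpha, beta, the Sinkhorn iteration count) are illustrative assumptions, and rrwm_sketch, sinkhorn and hungarian_discretize are hypothetical names.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def sinkhorn(M, n_iter=20):
    """Alternately normalize rows and columns of a positive matrix."""
    for _ in range(n_iter):
        M = M / M.sum(axis=1, keepdims=True)
        M = M / M.sum(axis=0, keepdims=True)
    return M

def rrwm_sketch(K, n1, n2, alpha=0.2, beta=30.0, max_iter=100, tol=1e-8):
    """Simplified reweighted random walk on a nonnegative, symmetric affinity K."""
    P = K / K.sum(axis=1).max()               # scale so every row sums to at most 1
    x = np.full(n1 * n2, 1.0 / (n1 * n2))     # uniform start over candidate matches
    for _ in range(max_iter):
        walk = P.T @ x                        # one random-walk step
        # Reweighting jump: sharpen with an exponential, then push toward one-to-one
        # matches with Sinkhorn on the n1 x n2 view of the column-major vector.
        y = np.exp(beta * walk / walk.max())
        y = sinkhorn(y.reshape(n1, n2, order='F')).reshape(-1, order='F')
        y = y / y.sum()
        x_new = alpha * walk + (1 - alpha) * y
        x_new = x_new / x_new.sum()
        converged = np.abs(x_new - x).sum() < tol
        x = x_new
        if converged:
            break
    return x.reshape(n1, n2, order='F')       # soft matching matrix

def hungarian_discretize(S):
    """Turn a soft n1 x n2 matching matrix into the 0/1 matching with maximum total score."""
    rows, cols = linear_sum_assignment(S, maximize=True)
    X = np.zeros_like(S, dtype=float)
    X[rows, cols] = 1.0
    return X

# Example: node affinities only, with the identity as the best matching.
P_node = np.full((3, 3), 0.1) + 0.8 * np.eye(3)
S = rrwm_sketch(np.diag(P_node.reshape(-1, order='F')), 3, 3)
X = hungarian_discretize(S)
```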
Correct matches are marked in green and wrong matches in red. In this example, the nodes are ordered by their ground-truth classes (i.e., the ground-truth matching matrix is the identity matrix).
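Because the ground-truth matching here is the identity, the per-row argmax used for coloring also gives a simple accuracy measure. Below is a minimal sketch. matching_accuracy is a hypothetical helper, and it assumes that the correct partner of row i is column i. With the PyTorch backend, X.numpy() would be passed in.

```python
import numpy as np

def matching_accuracy(X):
    """Fraction of rows whose argmax column equals the row index (identity ground truth)."""
    X = np.asarray(X)
    pred = X.argmax(axis=1)                        # predicted partner for each node of graph 1
    return float(np.mean(pred == np.arange(X.shape[0])))

acc = matching_accuracy([[1, 0, 0], [0, 0, 1], [0, 1, 0]])   # only node 0 is correct
```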
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by RRWM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
[Figure: Image Matching Result by RRWM]
We could also do a quick benchmarking of other solvers on this specific problem.
See ~pygmtools.classic_solvers.ipfp for the API reference. IPFP outputs a discrete matching matrix directly, so no Hungarian step is needed.
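IPFP (Leordeanu et al., 2009) alternates two steps. First, it computes a discrete matching that maximizes the linearized objective. Second, it takes a line-search step toward that matching. Below is a simplified NumPy/SciPy sketch. It is not the pygmtools implementation. It assumes a symmetric K, n1 <= n2 and a uniform starting point, and ipfp_sketch and _project_to_matching are hypothetical names.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def _project_to_matching(v, n1, n2):
    """Best 0/1 matching (as a column-major vector) for the linear score v."""
    scores = v.reshape(n1, n2, order='F')
    rows, cols = linear_sum_assignment(scores, maximize=True)
    b = np.zeros((n1, n2))
    b[rows, cols] = 1.0
    return b.reshape(-1, order='F')

def ipfp_sketch(K, n1, n2, max_iter=50):
    """Simplified IPFP for max vec(X)^T K vec(X); assumes symmetric K and n1 <= n2."""
    x = np.full(n1 * n2, 1.0 / n2)        # uniform soft start inside the constraint set
    best_b, best_score = None, -np.inf
    for _ in range(max_iter):
        b = _project_to_matching(K @ x, n1, n2)   # maximize the linearized objective
        score = b @ K @ b
        if score > best_score:            # keep the best discrete solution seen so far
            best_b, best_score = b, score
        d = b - x
        C = x @ K @ d
        D = d @ K @ d
        if D >= 0:                        # convex along d: the segment maximum is at b
            x_next = b
        else:                             # concave: line-search optimum, capped at 1
            x_next = x + min(-C / D, 1.0) * d
        if np.allclose(x_next, x):
            break
        x = x_next
    return best_b.reshape(n1, n2, order='F')

# Example: node affinities only, best matching 0->2, 1->0, 2->1.
P_node = np.full((3, 3), 0.1)
P_node[[0, 1, 2], [2, 0, 1]] = 0.9
X = ipfp_sketch(np.diag(P_node.reshape(-1, order='F')), 3, 3)
```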
X = pygm.ipfp(K, kpts1.shape[1], kpts2.shape[1])
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by IPFP')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
[Figure: Image Matching Result by IPFP]
See ~pygmtools.classic_solvers.sm for the API reference.
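Spectral matching (Leordeanu and Hebert, 2005) relaxes the QAP by taking the leading eigenvector of K as a soft matching. Below is a NumPy/SciPy sketch that finds this eigenvector by power iteration, assuming a nonnegative K. It then discretizes the result with the Hungarian algorithm, mirroring the pygm.hungarian step above. The original paper uses a greedy discretization instead. spectral_matching_sketch is a hypothetical name.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def spectral_matching_sketch(K, n1, n2, n_iter=100):
    """Leading eigenvector of a nonnegative K as soft scores, then one-to-one discretization."""
    v = np.ones(n1 * n2)
    for _ in range(n_iter):               # power iteration toward the leading eigenvector
        v = K @ v
        v = v / np.linalg.norm(v)
    S = v.reshape(n1, n2, order='F')      # soft matching scores (column-major vec)
    rows, cols = linear_sum_assignment(S, maximize=True)
    X = np.zeros((n1, n2))
    X[rows, cols] = 1.0                   # discrete one-to-one matching
    return S, X

# Example: node affinities only, best matching 0->1, 1->2, 2->0.
P_node = np.full((3, 3), 0.1)
P_node[[0, 1, 2], [1, 2, 0]] = 0.9
S, X = spectral_matching_sketch(np.diag(P_node.reshape(-1, order='F')), 3, 3)
```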
X = pygm.sm(K, kpts1.shape[1], kpts2.shape[1])
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by SM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
[Figure: Image Matching Result by SM]
See ~pygmtools.neural_solvers.ngm for the API reference.
Note
The NGM solvers are pretrained on a different problem setting, so their performance may appear inferior. To improve it, you may change how the affinity matrix is built, or try fine-tuning NGM on the new problem.
The NGM solver pretrained on the Willow dataset:
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='willow')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (willow pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
[Figure: Image Matching Result by NGM (willow pretrain)]
The NGM solver pretrained on the VOC dataset:
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='voc')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (voc pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)
[Figure: Image Matching Result by NGM (voc pretrain)]
Total running time of the script: (0 minutes 3.513 seconds)