Skip to content

Latest commit



731 lines (471 loc) · 22.1 KB


File metadata and controls

731 lines (471 loc) · 22.1 KB
.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

PyTorch Backend Example: Multi-Graph Matching

This example shows how to match multiple graphs. Multi-graph matching means that more than two graphs are jointly matched.

# Author: Zetian Jiang <>
#         Ziao Guo <>
# License: Mulan PSL v2 License


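The following solvers are included in this example:

* :func:`~pygmtools.classic_solvers.rrwm` (classic solver, used here for the initial pairwise matching)
* :func:`~pygmtools.multi_graph_solvers.cao` (multi-graph solver, in ``memory`` and ``time`` modes)
* :func:`~pygmtools.multi_graph_solvers.mgm_floyd` (multi-graph solver, in ``memory`` and ``time`` modes)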

import os
import time
import math
import copy
import torch  # pytorch backend
import itertools
import numpy as np
import pygmtools as pygm
import matplotlib.pyplot as plt  # for plotting
import scipy.io as sio  # for loading .mat file
import scipy.spatial as spa  # for Delaunay triangulation

from PIL import Image
from matplotlib.patches import ConnectionPatch # for plotting matching result

pygm.set_backend('pytorch')  # set default backend for pygmtools

Load the images

Images are from the Willow Object Class dataset (this dataset is also available with the Benchmark of pygmtools, see :class:`~pygmtools.dataset.WillowObject`).

The images are resized to 256x256.

obj_resize = (256, 256)  # target size passed to the benchmark loader
n_images = 30  # number of graphs (images) to match jointly
n_outlier = 0  # number of outlier keypoints; also read by cal_accuracy
img_list = []  # PIL images
kpts_list = []  # per image: (2, n_kpts) tensor, row 0 = x, row 1 = y
n_kpts_list = []  # per image: number of keypoints
perm_list = []  # per image: ground-truth permutation (identity, see below)

# NOTE(review): the arguments after ``name`` were truncated in this copy of the
# example; ``sets`` and ``obj_resize`` are restored from the Benchmark API -- confirm.
bm = pygm.benchmark.Benchmark(name='WillowObject',
                              sets='train',
                              obj_resize=obj_resize)

while len(img_list) < n_images:
    data_list, gt_dict, _ = bm.rand_get_data(cls='Car')
    for data in data_list:
        if len(img_list) >= n_images:
            # keep exactly n_images entries so every list stays aligned with
            # the loops below that iterate over range(n_images)
            break
        img = Image.fromarray(data['img'])
        # Sort keypoints by label so that index k carries the same label in
        # every image; the ground-truth permutation is then the identity.
        coords = sorted(data['kpts'], key=lambda x: x['labels'])
        kpts = torch.tensor([[kpt['x'] for kpt in coords],
                             [kpt['y'] for kpt in coords]])
        perm = np.eye(kpts.shape[1])
        img_list.append(img)
        kpts_list.append(kpts)
        n_kpts_list.append(kpts.shape[1])
        perm_list.append(perm)

Visualize the images and keypoints

def plot_image_with_graph(img, kpt, A=None):
    """
    Draw an image, its keypoints and (optionally) the graph edges on the current axes.

    :param img: image to display (passed to ``plt.imshow``)
    :param kpt: keypoint coordinates, ``kpt[0]`` holds the x values and ``kpt[1]`` the y values
    :param A: optional adjacency matrix (torch tensor); one line is drawn per non-zero entry
    """
    # The image was accepted as an argument but never shown; draw it first so
    # that keypoints and edges are overlaid on top of it.
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is not None:
        for idx in torch.nonzero(A, as_tuple=False):
            plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')

# Show every loaded image together with its keypoints on a grid of 5 rows.
plt.figure(figsize=(20, 18))
for i in range(n_images):
    plt.subplot(5, n_images // 5, i + 1)
    plt.title(f'Image {i + 1}')
    plot_image_with_graph(img_list[i], kpts_list[i])
# plt.savefig('image')
# plt.close()
.. image-sg:: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_001.png
   :alt: Image 1, Image 2, Image 3, Image 4, Image 5, Image 6, Image 7, Image 8, Image 9, Image 10, Image 11, Image 12, Image 13, Image 14, Image 15, Image 16, Image 17, Image 18, Image 19, Image 20, Image 21, Image 22, Image 23, Image 24, Image 25, Image 26, Image 27, Image 28, Image 29, Image 30
   :srcset: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_001.png
   :class: sphx-glr-single-img

Build the graphs

Graph structures are built based on the geometric structure of the keypoint set. In this example, we refer to Delaunay triangulation.

def delaunay_triangulation(kpt):
    """
    Build an adjacency matrix from the Delaunay triangulation of a keypoint set.

    :param kpt: torch tensor, (2, n), keypoint coordinates
    :return: torch tensor, (n, n), 1 where two keypoints belong to a common simplex, 0 elsewhere
    """
    # scipy expects one point per row, hence the transpose
    triangulation = spa.Delaunay(kpt.numpy().transpose())
    n_points = len(kpt[0])
    adjacency = torch.zeros(n_points, n_points)
    for simplex in triangulation.simplices:
        # every ordered pair of simplex vertices -> symmetric adjacency
        for src, dst in itertools.permutations(simplex, 2):
            adjacency[src, dst] = 1
    return adjacency

# Build one adjacency matrix per image from the Delaunay triangulation of its keypoints.
adj_list = []
for i in range(n_images):
    A = delaunay_triangulation(kpts_list[i])
    adj_list.append(A)  # store it: later sections index adj_list per image

Build affinity matrix

We follow the formulation of Quadratic Assignment Problem (QAP):

.. math::

    &\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\
    s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}

where the first step is to build the affinity matrix (:math:`\mathbf{K}`) for each pair of graphs.

def get_feature(n, points, adj):
    """
    Compute pairwise edge-length and edge-angle features of one graph.

    :param n: points # of graph
    :param points: torch tensor, (n, 2); a (2, n) tensor (row 0 = x, row 1 = y)
        is also accepted and transposed first
    :param adj: torch tensor, (n, n) (not used by this function)
    :return: edge feat, angle feat -- two (n, n) tensors; edge feat is the
        pairwise distance scaled by its maximum, angle feat is the edge
        direction ``atan(dy / dx)`` scaled by ``2 / pi``
    """
    # A (2, n) input must be transposed: reshaping it directly to (n, 1, 2)
    # would pair up coordinates of different points. n == 2 is ambiguous and
    # is interpreted as the documented (n, 2) layout.
    if points.dim() == 2 and points.shape[0] == 2 and points.shape[1] == n and n != 2:
        points = points.t()

    # diff[i, j] = points[i] - points[j], shape (n, n, 2) via broadcasting
    diff = points.reshape(n, 1, 2) - points.reshape(1, n, 2)
    edge_feat = torch.sqrt(torch.sum(diff ** 2, dim=2))
    # clamp the normaliser so coincident points give 0 instead of 0 / 0 = NaN
    edge_feat = edge_feat / torch.clamp(torch.max(edge_feat), min=1e-8)
    # the small constant avoids a division by zero for vertical edges
    angle_feat = torch.atan(diff[:, :, 1] / (diff[:, :, 0] + 1e-8))
    angle_feat = 2 * angle_feat / math.pi

    return edge_feat, angle_feat

def get_pair_affinity(edge_feat_1, angle_feat_1, edge_feat_2, angle_feat_2, adj1, adj2):
    """
    Edge-to-edge affinity between two graphs.

    :param edge_feat_1: edge feature of graph 1, first two dims (n1, n1)
    :param angle_feat_1: angle feature of graph 1, first two dims (n1, n1)
    :param edge_feat_2: edge feature of graph 2, first two dims (n2, n2)
    :param angle_feat_2: angle feature of graph 2, first two dims (n2, n2)
    :param adj1: adjacency of graph 1, (n1, n1)
    :param adj2: adjacency of graph 2, (n2, n2)
    :return: torch tensor, (n1, n2, n1, n2)
    """
    n1, n2 = edge_feat_1.shape[0], edge_feat_2.shape[0]
    assert n1 == angle_feat_1.shape[0] and n2 == angle_feat_2.shape[0]

    def _feature_distance(feat_1, feat_2):
        # Euclidean distance over the trailing feature dim for every pair
        # (edge of graph 1, edge of graph 2); broadcasting yields (n1, n1, n2, n2)
        lhs = feat_1.reshape(n1, n1, 1, 1, -1)
        rhs = feat_2.reshape(1, 1, n2, n2, -1)
        return torch.sqrt(torch.sum((lhs - rhs) ** 2, dim=-1))

    # 1 only where both graphs actually have the corresponding edge
    edge_mask = adj1.reshape(n1, n1, 1, 1) * adj2.reshape(1, 1, n2, n2)

    edge_weight = _feature_distance(edge_feat_1, edge_feat_2)
    angle_weight = _feature_distance(angle_feat_1, angle_feat_2)

    distance = edge_weight * 0.9 + angle_weight * 0.1
    similarity = torch.exp(-distance / 0.1) * edge_mask

    # (n1, n1, n2, n2) -> (n1, n2, n1, n2)
    return similarity.transpose(1, 2)

def generate_affinity_matrix(n_points, points_list, adj_list):
    """
    Build the affinity matrices of all ordered graph pairs.

    :param n_points: list, number of points of each graph
    :param points_list: list of point-coordinate tensors, one per graph (forwarded to ``get_feature``)
    :param adj_list: list of (n, n) adjacency tensors, one per graph
    :return: torch tensor, (m, m, n_max * n_max, n_max * n_max), where m is the
        number of graphs and n_max the largest graph size
    """
    # NOTE(review): get_feature documents ``points`` as (n, 2) -- confirm that the
    # entries of points_list use that layout.
    m = len(n_points)
    n_max = max(n_points)
    affinity = torch.zeros(m, m, n_max, n_max, n_max, n_max)

    edge_feat_list = []
    angle_feat_list = []
    for n, points, adj in zip(n_points, points_list, adj_list):
        edge_feat, angle_feat = get_feature(n, points, adj)
        # keep the per-graph features; the pairwise loop below indexes them
        edge_feat_list.append(edge_feat)
        angle_feat_list.append(angle_feat)

    for i, j in itertools.product(range(m), range(m)):
        pair_affinity = get_pair_affinity(edge_feat_list[i], angle_feat_list[i],
                                          edge_feat_list[j], angle_feat_list[j],
                                          adj_list[i], adj_list[j])
        # write into the top-left corner so graphs smaller than n_max are zero-padded
        n_i, n_j = n_points[i], n_points[j]
        affinity[i, j, :n_i, :n_j, :n_i, :n_j] = pair_affinity

    # swap the node axes within each pair, then flatten each (n_max, n_max) pair
    # into one axis of length n_max * n_max
    affinity = affinity.permute(0, 1, 3, 2, 5, 4).reshape(m, m, n_max * n_max, n_max * n_max)
    return affinity

# Affinity matrices of all graph pairs.
affinity_mat = generate_affinity_matrix(n_kpts_list, kpts_list, adj_list)

m = len(kpts_list)  # number of graphs
n = int(torch.tensor(n_kpts_list).max())  # largest number of keypoints
# one node count per ordered graph pair, for source and target graphs
ns_src = torch.full((m * m,), n, dtype=torch.int32)
ns_tgt = torch.full((m * m,), n, dtype=torch.int32)

Calculate accuracy, consistency, and affinity

def cal_accuracy(mat, gt_mat, n):
    """
    Mean matching accuracy over all ordered graph pairs.

    Reads the module-level ``n_outlier``.

    :param mat: predicted matchings, indexed as ``mat[i, j]`` for graph pair (i, j)
    :param gt_mat: ground-truth matchings, same layout as ``mat``
    :param n: number of nodes considered per graph
    :return: accuracy averaged over the m * m pairs
    """
    m = mat.shape[0]
    acc = 0
    for i in range(m):
        for j in range(m):
            _mat, _gt_mat = mat[i, j], gt_mat[i, j]
            # Summing over dim=1 gives one total per ROW, over dim=0 one per
            # COLUMN. The two were swapped before, which selected the wrong
            # rows/columns whenever the matched rows and columns differ.
            row_sum = torch.sum(_gt_mat, dim=1)
            col_sum = torch.sum(_gt_mat, dim=0)
            row_idx = [k for k in range(n) if row_sum[k] != 0]
            col_idx = [k for k in range(n) if col_sum[k] != 0]
            # restrict both matrices to rows/columns that have a ground-truth match
            _mat = _mat[row_idx, :][:, col_idx]
            _gt_mat = _gt_mat[row_idx, :][:, col_idx]
            # every wrong match contributes two differing entries, hence the / 2
            acc += 1 - torch.sum(torch.abs(_mat - _gt_mat)) / 2 / (n - n_outlier)
    return acc / (m * m)

def cal_consistency(mat, gt_mat, m, n):
    """
    Mean pairwise consistency of a set of matchings.

    :param mat: (m, m, n, n) all the matching results
    :param gt_mat: not used
    :param m: not used
    :param n: not used
    :return: mean over the values returned by ``get_batch_pc_opt``
    """
    pairwise_consistency = get_batch_pc_opt(mat)
    return torch.mean(pairwise_consistency)

def cal_affinity(X, X_gt, K, m, n):
    """
    Mean ratio between the affinity score of ``X`` and that of the ground truth.

    :param X: matching results, reshaped to (-1, n, n)
    :param X_gt: ground-truth matchings, reshaped to (-1, n, n)
    :param K: affinity matrices, reshaped to (-1, n * n, n * n)
    :param m: not used
    :param n: number of nodes per graph
    :return: mean of affinity(X) / (affinity(X_gt) + 1e-8)
    """
    pred_flat = X.reshape(-1, n, n)
    gt_flat = X_gt.reshape(-1, n, n)
    K_flat = K.reshape(-1, n * n, n * n)
    score = get_batch_affinity(pred_flat, K_flat)
    score_gt = get_batch_affinity(gt_flat, K_flat)
    # the small constant guards against a zero ground-truth score
    ratio = score / (score_gt + 1e-8)
    return torch.mean(ratio)

def get_batch_affinity(X, K, norm=1):
    """
    calculate affinity score

    :param X: (b, n, n)
    :param K: (b, n*n, n*n)
    :param norm: normalization term
    :return: affinity_score (b, 1, 1)
    """
    b, _, _ = X.size()
    # transposing before flattening gives the column-wise vectorisation of X
    vx = X.transpose(1, 2).reshape(b, -1, 1)  # (b, n*n, 1)
    vxt = vx.transpose(1, 2)  # (b, 1, n*n)
    affinity = torch.bmm(torch.bmm(vxt, K), vx) / norm
    return affinity

def get_single_affinity(X, K, norm=1):
    """
    calculate affinity score

    :param X: (n, n)
    :param K: (n*n, n*n)
    :param norm: normalization term
    :return: affinity_score, a (1, 1) tensor
    """
    # transposing before flattening gives the column-wise vectorisation of X
    vx = X.transpose(0, 1).reshape(-1, 1)  # (n*n, 1)
    vxt = vx.transpose(0, 1)  # (1, n*n)
    affinity = torch.matmul(torch.matmul(vxt, K), vx) / norm
    return affinity

def get_single_pc(X, i, j, Xij=None):
    """
    Pairwise consistency of one matching, computed with a loop over all graphs.

    :param X: (m, m, n, n) all the matching results
    :param i: index
    :param j: index
    :param Xij: (n, n) matching; defaults to ``X[i, j]``
    :return: the consistency of X_ij
    """
    m, _, n, _ = X.size()
    if Xij is None:
        Xij = X[i, j]
    pair_con = 0
    for k in range(m):
        # matching i -> j obtained by going through graph k
        X_combo = torch.matmul(X[i, k], X[k, j])
        pair_con += torch.sum(torch.abs(Xij - X_combo)) / (2 * n)
    return 1 - pair_con / m

def get_single_pc_opt(X, i, j, Xij=None):
    """
    Pairwise consistency of one matching, computed with a single batched product.

    :param X: (m, m, n, n) all the matching results
    :param i: index
    :param j: index
    :param Xij: (n, n) matching; defaults to ``X[i, j]``
    :return: the consistency of X_ij
    """
    m, _, n, _ = X.size()
    if Xij is None:
        Xij = X[i, j]
    X1 = X[i, :].reshape(-1, n, n)  # X1[k] = X[i, k]
    X2 = X[:, j].reshape(-1, n, n)  # X2[k] = X[k, j]
    X_combo = torch.bmm(X1, X2)  # X_combo[k] = X[i, k] @ X[k, j]
    # Xij is broadcast against all m compositions
    pair_con = 1 - torch.sum(torch.abs(Xij - X_combo)) / (2 * n * m)
    return pair_con

def get_batch_pc(X):
    """
    Pairwise consistency of every matching, computed pair by pair.

    :param X: (m, m, n, n) all the matching results
    :return: (m, m) the consistency of X
    """
    # take m from the input instead of relying on a module-level variable
    m = X.size(0)
    # allocate on the device of X rather than forcing CUDA, so CPU-only runs work
    pair_con = torch.zeros(m, m, device=X.device)
    for i in range(m):
        for j in range(m):
            pair_con[i, j] = get_single_pc_opt(X, i, j)
    return pair_con

def get_batch_pc_opt(X):
    """
    Pairwise consistency of every matching, computed with one broadcast product.

    :param X: (m, m, n, n) all the matching results
    :return: (m, m) the consistency of X
    """
    m, _, n, _ = X.size()
    # Broadcasting replaces the explicit ``repeat`` copies of the inputs.
    X1 = X.unsqueeze(1)  # (m, 1, m, n, n): X1[i, :, k] = X[i, k]
    X2 = X.transpose(0, 1).unsqueeze(0)  # (1, m, m, n, n): X2[:, j, k] = X[k, j]
    X_combo = torch.matmul(X1, X2)  # (m, m, m, n, n): X[i, k] @ X[k, j]
    X_ori = X.unsqueeze(2)  # (m, m, 1, n, n), broadcast over k
    pair_con = 1 - torch.sum(torch.abs(X_combo - X_ori), dim=(2, 3, 4)) / (2 * n * m)
    return pair_con

def eval(mat, gt_mat, affinity, m, n):
    """
    Evaluate a set of matchings with the three metrics used in this example.

    NOTE: this name shadows Python's builtin ``eval`` within this script.

    :param mat: matching results
    :param gt_mat: ground-truth matchings
    :param affinity: affinity matrices of all graph pairs
    :param m: number of graphs
    :param n: number of nodes per graph
    :return: accuracy, affinity score, consistency
    """
    accuracy = cal_accuracy(mat, gt_mat, n)
    affinity_score = cal_affinity(mat, gt_mat, affinity, m, n)
    consistency = cal_consistency(mat, gt_mat, m, n)
    return accuracy, affinity_score, consistency

Generate gt mat

# Ground-truth matching of every ordered graph pair (i, j): perm_i^T @ perm_j.
gt_mat = torch.zeros(m, m, n, n)
for i, j in itertools.product(range(m), range(m)):
    gt_mat[i, j] = torch.tensor(np.matmul(perm_list[i].transpose(0, 1), perm_list[j]))
# print(perm_list[0])
# print(perm_list[1])
# print(gt_mat[1, 2])
# print(gt_mat[0, 1] - gt_mat[1, 0].transpose(0, 1))

Pairwise graph matching by RRWM

See :func:`~pygmtools.classic_solvers.rrwm` for the API reference.

# indices of the two graphs whose matching is visualised in every figure
a = 0
b = 12
tic = time.time()
# solve all m * m pairwise problems in one batch, then discretise the result
# with the Hungarian linear solver
rrwm_mat = pygm.classic_solvers.rrwm(affinity_mat.reshape(-1, n * n, n * n), ns_src, ns_tgt)
rrwm_mat = pygm.linear_solvers.hungarian(rrwm_mat)
toc = time.time()
rrwm_mat = rrwm_mat.reshape(m, m, n, n)
rrwm_acc, rrwm_src, rrwm_con = eval(rrwm_mat, gt_mat, affinity_mat, m, n)
rrwm_tim = toc - tic

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by RRWM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = rrwm_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    # green: keypoint i matched to the keypoint with the same index, red otherwise
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    # the patch is only drawn once it is attached to an axes
    plt.gca().add_artist(con)
# plt.savefig("RRWM.png")
# plt.close()
.. image-sg:: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_002.png
   :alt: Multi-Graph Matching Result by RRWM
   :srcset: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_002.png
   :class: sphx-glr-single-img

Multi graph matching by multi-graph solvers

Multi graph matching: CAO-M

See :func:`~pygmtools.multi_graph_solvers.cao` for the API reference.

# CAO starts from a copy of the pairwise RRWM solution
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
cao_m_mat = pygm.multi_graph_solvers.cao(affinity_mat, base_mat, mode='memory')
cao_m_mat = pygm.linear_solvers.hungarian(cao_m_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
cao_m_acc, cao_m_src, cao_m_con = eval(cao_m_mat, gt_mat, affinity_mat, m, n)
# total time includes the RRWM run that produced the initial solution
cao_m_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by CAO-M')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = cao_m_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    # green: keypoint i matched to the keypoint with the same index, red otherwise
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    # the patch is only drawn once it is attached to an axes
    plt.gca().add_artist(con)
# plt.savefig("CAO-M.png")
# plt.close()
.. image-sg:: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_003.png
   :alt: Multi-Graph Matching Result by CAO-M
   :srcset: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_003.png
   :class: sphx-glr-single-img

Multi graph matching: CAO-T

See :func:`~pygmtools.multi_graph_solvers.cao` for the API reference.

# CAO starts from a copy of the pairwise RRWM solution
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
cao_t_mat = pygm.multi_graph_solvers.cao(affinity_mat, base_mat, mode='time')
cao_t_mat = pygm.linear_solvers.hungarian(cao_t_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
cao_t_acc, cao_t_src, cao_t_con = eval(cao_t_mat, gt_mat, affinity_mat, m, n)
# total time includes the RRWM run that produced the initial solution
cao_t_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by CAO-T')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = cao_t_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    # green: keypoint i matched to the keypoint with the same index, red otherwise
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    # the patch is only drawn once it is attached to an axes
    plt.gca().add_artist(con)
# plt.savefig("CAO-T.png")
# plt.close()
.. image-sg:: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_004.png
   :alt: Multi-Graph Matching Result by CAO-T
   :srcset: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_004.png
   :class: sphx-glr-single-img

Multi graph matching: MGM-Floyd-M

See :func:`~pygmtools.multi_graph_solvers.mgm_floyd` for the API reference.

# MGM-Floyd starts from a copy of the pairwise RRWM solution
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
floyd_m_mat = pygm.multi_graph_solvers.mgm_floyd(affinity_mat, base_mat, param_lambda=0.4, mode='memory')
floyd_m_mat = pygm.linear_solvers.hungarian(floyd_m_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
floyd_m_acc, floyd_m_src, floyd_m_con = eval(floyd_m_mat, gt_mat, affinity_mat, m, n)
# total time includes the RRWM run that produced the initial solution
floyd_m_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by Floyd-M')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = floyd_m_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    # green: keypoint i matched to the keypoint with the same index, red otherwise
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    # the patch is only drawn once it is attached to an axes
    plt.gca().add_artist(con)
# plt.savefig("Floyd-M.png")
# plt.close()
.. image-sg:: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_005.png
   :alt: Multi-Graph Matching Result by Floyd-M
   :srcset: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_005.png
   :class: sphx-glr-single-img

Multi graph matching: MGM-Floyd-T

See :func:`~pygmtools.multi_graph_solvers.mgm_floyd` for the API reference.

# MGM-Floyd starts from a copy of the pairwise RRWM solution
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
floyd_t_mat = pygm.multi_graph_solvers.mgm_floyd(affinity_mat, base_mat, param_lambda=0.6, mode='time')
floyd_t_mat = pygm.linear_solvers.hungarian(floyd_t_mat.reshape(-1, n, n)).reshape(m, m, n, n)
toc = time.time()
floyd_t_acc, floyd_t_src, floyd_t_con = eval(floyd_t_mat, gt_mat, affinity_mat, m, n)
# total time includes the RRWM run that produced the initial solution
floyd_t_tim = toc - tic + rrwm_tim

plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by Floyd-T')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = floyd_t_mat[a, b]
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    # green: keypoint i matched to the keypoint with the same index, red otherwise
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    # the patch is only drawn once it is attached to an axes
    plt.gca().add_artist(con)
# plt.savefig("Floyd-T.png")
# plt.close()
.. image-sg:: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_006.png
   :alt: Multi-Graph Matching Result by Floyd-T
   :srcset: /auto_examples/4.multi-graph_matching/images/sphx_glr_plot_multi_graph_match_pytorch_006.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 22.959 seconds)

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: <>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_multi_graph_match_pytorch.ipynb <plot_multi_graph_match_pytorch.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <>`_