In [2]:
import sys

# Put the parent directory on the import path so the `from src...` imports
# below (and in later cells) resolve; this must run before those imports.
sys.path.append("..")
import matplotlib.pyplot as plt
import pygmtools as pygm
import seaborn as sns

# NOTE(review): plt, sns and GCN_Net are not used in the visible cells.
from src.networks.gcn_net import GCN_Net

# Route all pygmtools calls through the PyTorch backend.
pygm.set_backend("pytorch")
# Dataset root; init() reads the validation split file from here and also
# passes it to MatchDataset as its data directory.
dataset_path = "../data/match_dataset/"

In [3]:
import torch_geometric.loader as GraphLoader

from src.match_dataset import MatchDataset


def init():
    """Create the validation loader for the match dataset.

    Reads ``validate_parts.json`` under ``dataset_path`` via ``MatchDataset``
    and wraps it in a torch_geometric ``DataLoader`` that yields one sample
    per batch, in file order (``batch_size=1``, ``shuffle=False``).

    Returns:
        The validation ``DataLoader``.
    """
    print("Loading dataset...")
    split_file = f"{dataset_path}/validate_parts.json"
    dataset = MatchDataset(split_file, dataset_path)
    loader = GraphLoader.DataLoader(dataset, shuffle=False, batch_size=1)
    print("Dataset loaded.")
    return loader


val_loader = init()

Loading dataset...
Dataset loaded.


In [4]:
# Each loader item unpacks into (ego_preds, cav_preds, K, gt); take the first one.
[ego_preds, cav_preds, K, gt] = next(iter(val_loader))

In [5]:
import torch

pygm.set_backend("pytorch")
from src.utils.lap import build_affinity_matrix

# Drop only the leading batch axis added by the DataLoader (batch_size=1).
# A bare .squeeze() would also collapse any other size-1 axis -- e.g. the node
# axis of a frame with a single detection (presumably preds are (1, N, F);
# TODO confirm against MatchDataset).
ego_preds = ego_preds.squeeze(0)
cav_preds = cav_preds.squeeze(0)
K = K.squeeze(0)
# Reference builder from src.utils.lap: takes numpy predictions (cav first)
# and returns a 3-tuple. NOTE: a later cell defines a local
# build_affinity_matrix with a different signature that shadows this name.
K1, _, _ = build_affinity_matrix(cav_preds.numpy(), ego_preds.numpy())
torch.allclose(K, K1)

True

In [None]:
def build_affinity_matrix(
    node_aff_mat: torch.Tensor,
    edge_aff_mat: torch.Tensor,
    graph1_edges: torch.Tensor,
    graph2_edges: torch.Tensor,
) -> torch.Tensor:
    """Build the second-order affinity matrix K.

    Assignment (i, a) -- node ``i`` of graph 1 matched to node ``a`` of
    graph 2 -- occupies flat index ``a * num_nodes1 + i`` on both axes of K.

    Args:
        node_aff_mat: node similarity matrix, shape (num_nodes1, num_nodes2).
            Required: it supplies the node counts, and the device/dtype when
            ``edge_aff_mat`` is None. Its entries fill the diagonal of K.
        edge_aff_mat: edge similarity matrix, shape (num_edges1, num_edges2),
            or None to build a node-only (diagonal) affinity matrix.
        graph1_edges: edge list of graph 1, shape (>= num_edges1, 2); each row
            is (start_node, end_node). Only the first num_edges1 rows are used.
        graph2_edges: edge list of graph 2, shape (>= num_edges2, 2); only the
            first num_edges2 rows are used.

    Returns:
        affinity_matrix: shape (num_nodes1 * num_nodes2, num_nodes1 * num_nodes2).
        Entry [a*n1 + i, b*n1 + j] is edge_aff_mat[e1, e2] for graph-1 edge
        e1 = (i, j) and graph-2 edge e2 = (a, b); the diagonal entry
        [a*n1 + i, a*n1 + i] is node_aff_mat[i, a].

    Raises:
        ValueError: if ``node_aff_mat`` is None.
    """
    if node_aff_mat is None:
        raise ValueError("node_aff_mat is required to determine the node counts")
    num_nodes1, num_nodes2 = node_aff_mat.shape

    # The edge matrix, when present, decides device/dtype; else use the nodes.
    ref = edge_aff_mat if edge_aff_mat is not None else node_aff_mat

    # 4-D layout [a, i, b, j]: flattening to 2-D yields index a*n1 + i per axis.
    affinity_matrix = torch.zeros(
        num_nodes2, num_nodes1, num_nodes2, num_nodes1, dtype=ref.dtype, device=ref.device
    )

    if edge_aff_mat is not None:
        # Only read the edge shape when an edge matrix was actually given.
        num_edges1, num_edges2 = edge_aff_mat.shape
        edge_indices = _build_edge_indices(
            graph1_edges[:num_edges1], graph2_edges[:num_edges2], num_edges1, num_edges2
        )
        # Row-major flattening (e1 outer, e2 inner) matches the pair order of
        # _build_edge_indices.
        affinity_matrix[edge_indices] = edge_aff_mat.reshape(-1)

    affinity_matrix = affinity_matrix.reshape(
        num_nodes2 * num_nodes1, num_nodes2 * num_nodes1
    )

    # Node affinities are written last, so they take precedence over any
    # self-loop edge pair that landed on the diagonal. Transposing to (n2, n1)
    # before flattening gives index a*n1 + i. diagonal() is a view, so copy_
    # writes straight into affinity_matrix.
    affinity_matrix.diagonal().copy_(node_aff_mat.t().reshape(-1))

    return affinity_matrix


def _build_edge_indices(
    edges1: torch.Tensor, edges2: torch.Tensor, num_edges1: int, num_edges2: int
) -> tuple[torch.Tensor, ...]:
    """Enumerate every (graph-1 edge, graph-2 edge) pair as 4-D scatter indices.

    Pairs come out in row-major order -- graph-1 edge outer, graph-2 edge
    inner, i.e. pair k = e1 * num_edges2 + e2 -- which is the element order
    of ``edge_aff_mat.reshape(-1)``.

    Args:
        edges1: edges of graph 1, shape (num_edges1, 2).
        edges2: edges of graph 2, shape (num_edges2, 2).
        num_edges1: number of graph-1 edges.
        num_edges2: number of graph-2 edges.

    Returns:
        Index tuple (start_g2, start_g1, end_g2, end_g1), each of length
        num_edges1 * num_edges2.
    """
    # Columns: [start_g1, end_g1, start_g2, end_g2] for every edge pair.
    combined_edges = torch.cat(
        [edges1.repeat_interleave(num_edges2, dim=0), edges2.repeat(num_edges1, 1)],
        dim=1,
    )

    return (
        combined_edges[:, 2],  # start_g2
        combined_edges[:, 0],  # start_g1
        combined_edges[:, 3],  # end_g2
        combined_edges[:, 1],  # end_g1
    )


In [12]:
from src.utils.tools import (
    build_conn_edge,
    build_graph,
    edge_affinity_fn,
    node_affinity_fn,
)

# Build both graphs, their connectivity / edge features, and the hand-crafted
# node and edge affinities (the affinity fns are batched, so add then drop a
# leading batch axis).
ego_graph = build_graph(ego_preds)
cav_graph = build_graph(cav_preds)
# Node counts as 1-element tensors (not used in the visible cells).
n1 = torch.tensor([ego_graph.shape[0]])
n2 = torch.tensor([cav_graph.shape[0]])
conn1, edge1 = build_conn_edge(ego_graph)
conn2, edge2 = build_conn_edge(cav_graph)
batched_node = node_affinity_fn(ego_preds.unsqueeze(0), cav_preds.unsqueeze(0))
node_mat = batched_node[0]
batched_edge = edge_affinity_fn(edge1.unsqueeze(0), edge2.unsqueeze(0))
edge_mat = batched_edge[0]
del batched_node, batched_edge

In [17]:
import torch
import torch.nn as nn


# PyTorch prototype
class NodeMLP(nn.Module):
    """Learned node-affinity model between two sets of node predictions.

    Column layout read by ``forward`` (semantics per the author's comments;
    TODO confirm against the dataset): columns 1-5 are geometric features
    (x, y, h, w, alpha) embedded by a shared MLP, column 6 is a class id and
    columns 7+ are a per-class confidence vector.
    """

    def __init__(self):
        super().__init__()
        # Shared-weight branch applied to both graphs' nodes.
        self.branch = nn.Sequential(
            nn.Linear(5, 64),  # input (x, y, h, w, alpha)
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        # Similarity over the feature (last) axis so it broadcasts pairwise:
        # (N1, 1, 32) vs (1, N2, 32) -> (N1, N2).
        self.distance = nn.CosineSimilarity(dim=-1)

    def forward(self, preds1, preds2):
        """Compute the pairwise node-affinity matrix.

        Args:
            preds1: (N1, F) node predictions of graph 1 (F >= 7).
            preds2: (N2, F) node predictions of graph 2.

        Returns:
            (N1, N2) tensor: same-class mask * confidence overlap * cosine
            similarity of the embedded geometric features.
        """
        x1, x2 = preds1[:, 1:6], preds2[:, 1:6]
        cls1, cls2 = preds1[:, 6], preds2[:, 6]
        conf1, conf2 = preds1[:, 7:], preds2[:, 7:]
        feat1 = self.branch(x1)
        feat2 = self.branch(x2)
        # (N1, N2) mask: True where both nodes carry the same class id.
        cls_dist = cls1.view(-1, 1) == cls2.view(1, -1)
        # (N1, N2) confidence overlap: sum_k sqrt(conf1[i, k] * conf2[j, k]).
        conf_dist = torch.sum(
            torch.sqrt(conf1.unsqueeze(1) * conf2.unsqueeze(0)), dim=-1
        )
        # Pairwise (i, j) cosine similarity; comparing feat1 and feat2 row by
        # row would only give an (N,) vector and fails when N1 != N2.
        feat_sim = self.distance(feat1.unsqueeze(1), feat2.unsqueeze(0))
        return cls_dist * conf_dist * feat_sim


In [18]:
# PyTorch prototype
class EdgeMLP(nn.Module):
    """Learned edge-affinity model between two edge sets.

    Column layout read by ``forward`` (semantics per the author's comments;
    TODO confirm against build_conn_edge): columns 0-2 are edge features
    (delta dist, delta theta, delta alpha) embedded by a shared MLP, columns
    3+ are class ids that must all match for two edges to be comparable.
    """

    def __init__(self):
        super().__init__()
        # Shared-weight branch applied to both graphs' edges.
        self.branch = nn.Sequential(
            nn.Linear(3, 64),  # input (delta dist, delta theta, delta alpha)
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        # Similarity over the feature (last) axis so it broadcasts pairwise:
        # (E1, 1, 32) vs (1, E2, 32) -> (E1, E2).
        self.distance = nn.CosineSimilarity(dim=-1)

    def forward(self, edges1, edges2):
        """Compute the pairwise edge-affinity matrix.

        Args:
            edges1: (E1, C) edge features of graph 1 (C >= 3).
            edges2: (E2, C) edge features of graph 2.

        Returns:
            (E1, E2) tensor: same-class mask * cosine similarity of the
            embedded edge features.
        """
        x1, x2 = edges1[:, :3], edges2[:, :3]
        feat1 = self.branch(x1)
        feat2 = self.branch(x2)
        # (E1, E2) mask: True where every class-id column of edge i equals the
        # corresponding column of edge j (== broadcasts, no expand needed).
        cls_edge1, cls_edge2 = edges1[:, 3:].int(), edges2[:, 3:].int()
        cls_aff = (cls_edge1.unsqueeze(1) == cls_edge2.unsqueeze(0)).all(dim=-1)
        # Pairwise (i, j) cosine similarity; a row-by-row comparison would
        # only give an (E,) vector and fails when E1 != E2.
        feat_sim = self.distance(feat1.unsqueeze(1), feat2.unsqueeze(0))
        return cls_aff * feat_sim


In [20]:
# Untrained prototype models. Seed first so their random initial weights --
# and hence node_aff_mat / edge_aff_mat and the K3 comparison -- are
# reproducible under Restart & Run All.
torch.manual_seed(0)
node_mlp = NodeMLP()
edge_mlp = EdgeMLP()
node_aff_mat = node_mlp(ego_preds, cav_preds)
edge_aff_mat = edge_mlp(edge1, edge2)

In [None]:
# Rebuild K from the hand-crafted affinities with the locally defined builder
# and compare it with the dataset's K.
K2 = build_affinity_matrix(
    node_aff_mat=node_mat,
    edge_aff_mat=edge_mat,
    graph1_edges=conn1,
    graph2_edges=conn2,
)
K.allclose(K2)

True

In [21]:
# Same construction with the MLP affinities. NodeMLP / EdgeMLP are not trained
# anywhere in the visible cells, so a match with the dataset's K is not expected.
K3 = build_affinity_matrix(
    node_aff_mat, edge_aff_mat, graph1_edges=conn1, graph2_edges=conn2
)
K.allclose(K3)

False