In [15]:
import torch
from torch import nn
from hungarian import Hungarian

In [3]:
class CIE(nn.Module):
    """
    Applies a channel-independent update rule to node and edge features.

    Node features are updated from their neighbours (weighted by the adjacency
    matrix and by the edge features lifted to ``d`` channels); edge features are
    updated from the absolute difference of the projected features of their two
    end nodes.

    Args:
        d: number of node feature channels.
        m: number of edge feature channels.
    """
    def __init__(self, d, m):
        super().__init__()

        self.W0 = nn.Parameter(torch.empty(size=(d, d)), requires_grad=True)
        self.W1 = nn.Parameter(torch.empty(size=(d, m)), requires_grad=True)
        self.W2 = nn.Parameter(torch.empty(size=(d, d)), requires_grad=True)
        self.W3 = nn.Parameter(torch.empty(size=(m, m)), requires_grad=True)
        self.W4 = nn.Parameter(torch.empty(size=(d, m)), requires_grad=True)

        # torch.empty returns uninitialised memory, so every weight -- W4
        # included -- must be explicitly initialised before use.
        for weight in (self.W0, self.W1, self.W2, self.W3, self.W4):
            nn.init.xavier_normal_(weight)

    def forward(self, H, A, E):
        """
        Compute updated node and edge features.

        Args:
            H: node features, bs * k * d
            A: adjacency matrix, bs * k * k
            E: edge features, bs * m * k * k

        Returns:
            (H_out, E_out) with the same shapes as ``H`` and ``E``.
        """
        batch_size, k, d = H.shape
        m = E.shape[1]

        # Lift the m edge channels to d channels: (bs, k*k, m) @ (m, d).
        W1E = E.reshape(batch_size, m, k * k).permute(0, 2, 1) @ self.W1.T
        W1E = W1E.permute(0, 2, 1).reshape(batch_size, d, k, k)  # (bs, d, k, k)

        W0H = H @ self.W0  # (bs, k, d)
        W2H = H @ self.W2  # (bs, k, d)
        W4H = H @ self.W4  # (bs, k, m)
        W3E = (E.permute(0, 2, 3, 1) @ self.W3).permute(0, 3, 1, 2)  # (bs, m, k, k)

        # Per channel c: sum_j A[i, j] * W1E[c, i, j] * W2H[j, c].
        # unsqueeze is used instead of view so non-contiguous inputs also work.
        weighted_edges = A.unsqueeze(1) * W1E                       # (bs, d, k, k)
        messages = weighted_edges @ W2H.permute(0, 2, 1).unsqueeze(-1)  # (bs, d, k, 1)
        messages = messages.squeeze(-1).permute(0, 2, 1)            # (bs, k, d)

        H_out = torch.relu(messages) + torch.relu(W0H)

        # Pairwise difference of projected node features, one k*k map per edge channel.
        node_proj = W4H.permute(0, 2, 1)                            # (bs, m, k)
        h = node_proj.unsqueeze(-1) - node_proj.unsqueeze(-2)       # (bs, m, k, k)
        hE = torch.abs(h) * E

        E_out = torch.relu(hE) + torch.relu(W3E)

        return H_out, E_out

In [13]:
class SimilarityMatrix(nn.Module):
    """
    Computes a similarity matrix between two graphs based on the node features.

    The similarity is ``exp(H1 @ L @ H2^T / tau)`` where ``L`` is a learnable
    d * d matrix initialised as a diagonal matrix with uniform(0, 1) entries.

    Args:
        d: number of node feature channels.
        tau: temperature dividing the affinity before the exponential.
            Smaller values sharpen the similarities; note that very small
            values make ``exp`` more likely to overflow.
    """
    def __init__(self, d, tau=0.1):
        super().__init__()

        arr = torch.empty((d, ))
        arr.uniform_()

        self.L = nn.Parameter(torch.diag(arr), requires_grad=True)
        self.tau = tau

    def forward(self, H1, H2):
        """
        Args:
            H1: node features of the first graph, (..., k1, d)
            H2: node features of the second graph, (..., k2, d)

        Returns:
            Strictly positive similarity matrix of shape (..., k1, k2).
        """
        # Transpose only the last two dims: `.T` would reverse *all* dims of a
        # batched tensor and move the batch axis to the end.
        affinity = H1 @ self.L @ H2.transpose(-1, -2)
        return torch.exp(affinity / self.tau)

In [None]:
class Sinkhorn(nn.Module):
    """
    Brings a matrix to a doubly-stochastic form via Sinkhorn algorithm.

    Each iteration normalises the rows to sum to one and then the columns to
    sum to one.

    Args:
        n_iter: number of row/column normalisation rounds.
    """

    def __init__(self, n_iter=20):
        super().__init__()
        self.n_iter = n_iter

    def forward(self, M):
        """
        Args:
            M: batch of matrices with positive entries, bs * k * k

        Returns:
            Tensor of the same shape, dtype and device as ``M``; after the
            last step every column sums to one and rows sum to approximately
            one.
        """
        # Reductions on M itself keep the dtype/device of the input, unlike a
        # freshly allocated helper vector of ones.
        for _ in range(self.n_iter):
            M = M / M.sum(dim=-1, keepdim=True)  # row normalisation
            M = M / M.sum(dim=-2, keepdim=True)  # column normalisation

        return M

In [4]:
class CrossGraphMerging(nn.Module):
    """
    Performs feature merging between two graphs.

    Each node's own features are concatenated (along the channel axis) with the
    features of the other graph aggregated through the similarity matrix, and
    the 2*d-channel result is projected back to d channels by a shared weight.

    input: H1, H2 - node features bs * k * d
           S - similarity matrix of bs * k * k

    output: H1_out, H2_out - merged node features bs * k * d
    """

    def __init__(self, d):
        super().__init__()

        self.W = nn.Parameter(torch.empty(size=(2*d, d)), requires_grad=True)
        # torch.empty leaves the memory uninitialised; give W defined values.
        nn.init.xavier_normal_(self.W)

    def forward(self, H1, H2, S):
        # Concatenate along the channel axis so the result has 2*d channels,
        # matching W; the default dim=0 would stack along the batch axis.
        H1_out = torch.cat([H1, S @ H2], dim=-1) @ self.W
        # Transpose only the two node axes; `.T` would also move the batch axis.
        H2_out = torch.cat([H2, S.transpose(-1, -2) @ H1], dim=-1) @ self.W

        return H1_out, H2_out

In [None]:
def hung_attention(S, S_true):
    '''
    input: S - matching matrix bs * k * k
           S_true - true matching matrix bs * k * k

    output: Z - hungarian attention mask bs*k*k which should be elementwise
                multiplied by needed loss function

    For every sample the mask is 1 where either the Hungarian assignment
    computed from S or the ground truth S_true selects a pair, and 0 elsewhere.
    The mask is returned as a double tensor on the device of S and carries no
    gradient.
    '''

    bs = S.shape[0]
    Z = torch.zeros(S.shape, dtype=torch.double, device=S.device)

    for i in range(bs):
        # Hand the solver a plain NumPy array, detached from autograd and moved
        # to the CPU. NOTE(review): assumes `Hungarian` accepts array-like
        # input -- confirm against the `hungarian` package.
        hungarian = Hungarian(S[i].detach().cpu().numpy(), is_profit_matrix=True)
        hungarian.calculate()

        # (row, col) pairs chosen by the assignment; reshape keeps the
        # indexing below valid even when the result list is empty.
        idx = torch.tensor(hungarian.get_results(), dtype=torch.long,
                           device=S.device).reshape(-1, 2)

        # The buffer must be per-sample (k * k), not the whole batch shape.
        Z_buf = torch.zeros(S[i].shape, dtype=torch.long, device=S.device)
        Z_buf[idx[:, 0], idx[:, 1]] = 1

        # Union of predicted and ground-truth matches for *this* sample only.
        Z[i] = (Z_buf | S_true[i].long()).double()

    return Z


class HungarianLoss(nn.Module):
    """
    Summed binary cross-entropy between a predicted matching matrix and the
    ground truth, optionally masked by the Hungarian attention mask.

    Args:
        hung_attention: when true, the element-wise loss is multiplied by the
            mask returned by ``hung_attention(S, S_true)``; otherwise every
            element contributes.
    """
    def __init__(self, hung_attention=False):
        super().__init__()
        self.hung_attention = hung_attention

    def forward(self, S, S_true):
        """
        Args:
            S: predicted matching matrix with entries in [0, 1], bs * k * k
            S_true: ground-truth matching matrix, bs * k * k

        Returns:
            Scalar tensor: the (masked) cross-entropy summed over all elements.
        """
        if self.hung_attention:
            # Keep the mask in the dtype/device of S so the loss is not
            # silently promoted to float64.
            Z = hung_attention(S, S_true).to(dtype=S.dtype, device=S.device)
        else:
            Z = torch.ones_like(S)

        # Keep probabilities strictly inside (0, 1): log(0) = -inf and
        # 0 * -inf = nan would otherwise poison the loss for exact 0/1 entries.
        eps = torch.finfo(S.dtype).eps
        S = S.clamp(min=eps, max=1. - eps)

        cross_entropy = S_true * torch.log(S) + (1. - S_true) * torch.log(1. - S)
        loss = -torch.sum(Z * cross_entropy)

        return loss

In [14]:
class HungarianModel(nn.Module):
    """
    Two-layer graph matching network.

    Both graphs go through a shared CIE layer, a similarity matrix is computed
    and normalised with Sinkhorn, node features are merged across graphs, and a
    second shared CIE layer followed by similarity + Sinkhorn yields the final
    matching matrix.

    Args:
        d: number of node feature channels.
        m: number of edge feature channels.
    """
    def __init__(self, d, m):
        super().__init__()

        self.cie1 = CIE(d, m)
        self.cie2 = CIE(d, m)
        self.sim = SimilarityMatrix(d)
        self.cross = CrossGraphMerging(d)
        self.sinkhorn = Sinkhorn(n_iter=10)

    def forward(self, H1, E1, H2, E2, A1=None, A2=None):
        """
        Args:
            H1, H2: node features of the two graphs, bs * k * d
            E1, E2: edge features of the two graphs, bs * m * k * k
            A1, A2: optional adjacency matrices, bs * k * k. When omitted, an
                all-ones matrix is used, i.e. every node pair is treated as
                connected.

        Returns:
            S: Sinkhorn-normalised matching matrix, bs * k * k
        """
        # CIE.forward takes (H, A, E); build a default adjacency when the
        # caller does not supply one.
        if A1 is None:
            A1 = H1.new_ones(H1.shape[0], H1.shape[1], H1.shape[1])
        if A2 is None:
            A2 = H2.new_ones(H2.shape[0], H2.shape[1], H2.shape[1])

        # First embedding layer; the same weights are applied to both graphs.
        H1, E1 = self.cie1(H1, A1, E1)
        H2, E2 = self.cie1(H2, A2, E2)
        S = self.sinkhorn(self.sim(H1, H2))

        # Exchange node information between the graphs using the soft matching.
        H1, H2 = self.cross(H1, H2, S)

        # Second embedding layer, again shared between the graphs.
        H1, E1 = self.cie2(H1, A1, E1)
        H2, E2 = self.cie2(H2, A2, E2)
        S = self.sinkhorn(self.sim(H1, H2))

        return S