In [1]:
import torch
from torch import nn
from hungarian import Hungarian

In [2]:
class CIE(nn.Module):
    """
    Channel-independent embedding (CIE) layer: jointly updates node and edge features.

    Node update, for every output channel c::

        H_out[:, i, c] = relu( sum_j (E W1)[:, c, i, j] * (H W2)[:, j, c] ) + relu( (H W0)[:, i, c] )

    Edge update, for every output channel c::

        E_out[:, c, i, j] = relu( |(H W4)[:, i, c] - (H W4)[:, j, c]| * (E W5)[:, c, i, j] )
                            + relu( (E W3)[:, c, i, j] )

    where ``E W*`` is a linear map applied over the edge-channel axis.

    Args:
        d_in:  node feature size of the input.
        m_in:  number of input edge channels.
        d_out: node feature size of the output.
        m_out: number of output edge channels.
    """
    def __init__(self, d_in, m_in, d_out, m_out):
        super().__init__()
        self.d_out = d_out
        self.m_out = m_out
        self.W0 = nn.Parameter(torch.empty(size=(d_in, d_out)))  # node self-term
        self.W1 = nn.Parameter(torch.empty(size=(m_in, d_out)))  # edge -> node-message weights
        self.W2 = nn.Parameter(torch.empty(size=(d_in, d_out)))  # neighbour node projection
        self.W3 = nn.Parameter(torch.empty(size=(m_in, m_out)))  # edge self-term
        self.W4 = nn.Parameter(torch.empty(size=(d_in, m_out)))  # node projection for edge update
        self.W5 = nn.Parameter(torch.empty(size=(m_in, m_out)))  # edge gating by node differences

        # Initialised in a fixed order so runs with the same seed stay reproducible.
        for weight in (self.W0, self.W1, self.W2, self.W3, self.W4, self.W5):
            nn.init.xavier_normal_(weight)

    def forward(self, H, E, verbose=False):
        """
        Args:
            H: node features, shape (bs, k, d_in).
            E: edge features, shape (bs, m_in, k, k).
            verbose: if True, print the inputs, all weights and every intermediate tensor.

        Returns:
            (H_out, E_out) with shapes (bs, k, d_out) and (bs, m_out, k, k).
        """
        batch_size, k, _ = H.shape
        m_in = E.shape[1]
        d_out, m_out = self.d_out, self.m_out

        # Edge features with the channel axis last, flattened over node pairs:
        # (bs, m_in, k, k) -> (bs, k*k, m_in), so channel mixing is a plain matmul.
        E_flat = E.reshape(batch_size, m_in, k * k).transpose(1, 2)

        # ---- node update ----
        W0H = H @ self.W0  # (bs, k, d_out)
        W2H = H @ self.W2  # (bs, k, d_out)
        W1E = (E_flat @ self.W1).transpose(1, 2).reshape(batch_size, d_out, k, k)  # (bs, d_out, k, k)
        # Channel-wise matrix-vector product: for every channel c, the k x k matrix
        # W1E[:, c] times the column vector W2H[:, :, c].
        W1E_W2H = (W1E @ W2H.transpose(1, 2).unsqueeze(-1)).squeeze(-1).transpose(1, 2)  # (bs, k, d_out)
        H_out = torch.relu(W1E_W2H) + torch.relu(W0H)

        # ---- edge update ----
        W3E = (E.permute(0, 2, 3, 1) @ self.W3).permute(0, 3, 1, 2)  # (bs, m_out, k, k)
        W4H = H @ self.W4  # (bs, k, m_out)
        W5E = (E_flat @ self.W5).transpose(1, 2).reshape(batch_size, m_out, k, k)  # (bs, m_out, k, k)
        W4H_t = W4H.transpose(1, 2)  # (bs, m_out, k)
        # Pairwise differences: h[:, c, i, j] = W4H[:, i, c] - W4H[:, j, c].
        h = W4H_t.unsqueeze(-1) - W4H_t.unsqueeze(-2)
        hE = torch.abs(h) * W5E
        E_out = torch.relu(hE) + torch.relu(W3E)

        if verbose:
            for label, value in (
                ('E: ', E), ('H: ', H),
                ('W0: ', self.W0), ('W1: ', self.W1), ('W2: ', self.W2),
                ('W3: ', self.W3), ('W4: ', self.W4), ('W5: ', self.W5),
                ('W1E: ', W1E), ('W2H: ', W2H), ('W0H: ', W0H),
                ('W3E: ', W3E), ('W4H: ', W4H), ('W5E: ', W5E),
                ('W1E_W2H: ', W1E_W2H), ('H_out: ', H_out),
                ('h: ', h), ('hE: ', hE), ('E_out: ', E_out),
            ):
                print(label, value)

        return H_out, E_out

In [3]:
class SimilarityMatrix(nn.Module):
    """
    Learnable bilinear affinity between the nodes of two graphs.

    Entry (b, i, j) is exp(tau * H1[b, i] @ L @ H2[b, j] - m[b, i]), where m[b, i]
    is the largest logit in row i. The largest value in every row is therefore
    exactly 1, and exp cannot overflow.

    NOTE(review): tau multiplies the logits here; temperature formulations often
    divide by it instead — confirm which is intended.
    """
    def __init__(self, d, tau=0.1):
        super().__init__()
        # L starts as a diagonal matrix with U(0, 1) entries. It is a full (d, d)
        # parameter, so the off-diagonal entries are trainable as well.
        diag_init = torch.empty((d, )).uniform_()
        self.L = nn.Parameter(torch.diag(diag_init), requires_grad=True)
        self.tau = tau

    def forward(self, H1, H2):
        logits = self.tau * H1 @ self.L @ H2.transpose(1, 2)
        row_max = torch.max(logits, dim=-1, keepdim=True).values
        return torch.exp(logits - row_max)

In [4]:
class Sinkhorn(nn.Module):
    """
    Brings a batch of positive matrices to (approximately) doubly-stochastic form
    with the Sinkhorn algorithm: rows and columns are normalised alternately.

    Args:
        n_iter: number of row-then-column normalisation rounds.
    """

    def __init__(self, n_iter=20):
        super().__init__()
        self.n_iter = n_iter

    def forward(self, M):
        """
        Args:
            M: tensor of shape (..., k, k) whose row and column sums are strictly
               positive (e.g. an exponentiated similarity matrix).

        Returns:
            Tensor of the same shape, dtype and device. Its entries keep the sign of M,
            and each row and column sums to ~1. Columns sum to 1 exactly (up to
            rounding) because the column step runs last. With n_iter == 0, M is
            returned unchanged.
        """
        # Only elementwise division by sums: works on any device/dtype, keeps
        # positive entries positive (needed by log-based losses) and stays differentiable.
        for _ in range(self.n_iter):
            M = M / M.sum(dim=-1, keepdim=True)  # rows sum to 1
            M = M / M.sum(dim=-2, keepdim=True)  # columns sum to 1
        return M

In [5]:
class CrossGraphMerging(nn.Module):
    """
    Mixes node features across two graphs through a soft matching matrix.

    Each graph's node features are concatenated with the features routed over from
    the other graph, then projected back to size d by one shared weight W:

        H1_out = [H1 || S   @ H2] @ W
        H2_out = [H2 || S^T @ H1] @ W

    input: H1, H2 - node features, bs * k * d
           S - similarity / soft matching matrix, bs * k * k

    output: H1_out, H2_out - merged node features, bs * k * d
    """

    def __init__(self, d):
        super().__init__()
        self.W = nn.Parameter(torch.empty(size=(2 * d, d)), requires_grad=True)
        nn.init.xavier_normal_(self.W)

    def forward(self, H1, H2, S):
        from_graph2 = S @ H2                    # graph-2 features seen by each graph-1 node
        from_graph1 = S.transpose(-1, -2) @ H1  # graph-1 features seen by each graph-2 node
        merged1 = torch.cat((H1, from_graph2), dim=-1)
        merged2 = torch.cat((H2, from_graph1), dim=-1)
        return merged1 @ self.W, merged2 @ self.W

In [6]:
def hung_attention(S, S_true):
    '''
    Build the Hungarian-attention mask for a batch of matching matrices.

    input: S - predicted matching matrix, bs*k*k (any device / floating dtype)
           S_true - ground-truth matching matrix, bs*k*k

    output: Z - 0/1 hungarian attention mask, bs*k*k, with the same dtype and
                device as S. It should be multiplied elementwise into the
                needed loss function. Z[b, i, j] = 1 where the Hungarian
                assignment of S[b] (used as a profit matrix) selects (i, j), or
                where S_true[b, i, j], cast to an integer, is non-zero.
    '''
    bs = S.shape[0]
    Z = torch.zeros_like(S, requires_grad=False)

    for b in range(bs):
        # The Hungarian solver works on host data: detach from the autograd graph
        # and move to the CPU so CUDA inputs are supported as well.
        solver = Hungarian(S[b].detach().cpu(), is_profit_matrix=True)
        solver.calculate()
        pairs = torch.tensor(solver.get_results(), dtype=torch.long, device=S.device).reshape(-1, 2)

        # Union of ground truth and Hungarian assignment, built on S's device.
        # `!= 0` always creates a new tensor, so S_true is never modified in place.
        mask = S_true[b].to(S.device).long() != 0
        mask[pairs[:, 0], pairs[:, 1]] = True
        Z[b] = mask.to(S.dtype)

    return Z


class PermutationLoss(nn.Module):
    """
    Summed binary cross-entropy between a predicted matching matrix S and the
    ground-truth matching S_true. The loss is optionally masked by the
    Hungarian-attention mask from `hung_attention`.

    Args:
        hung_attention: if True, weight each entry by the mask returned by the
            module-level `hung_attention(S, S_true)`. Otherwise every entry has weight 1.
    """
    def __init__(self, hung_attention=False):
        super().__init__()
        self.hung_attention = hung_attention

    def forward(self, S, S_true):
        """
        Args:
            S: predicted matching probabilities, expected in [0, 1].
            S_true: ground-truth matching with the same shape as S.

        Returns:
            Scalar tensor: -sum(Z * (S_true*log(S) + (1-S_true)*log(1-S))).
        """
        if self.hung_attention:
            Z = hung_attention(S, S_true)
        else:
            Z = torch.ones_like(S)  # same device/dtype as S

        # Clamp the log terms at -100, as torch.nn.functional.binary_cross_entropy does.
        # Saturated predictions (S == 0 or S == 1) then give a finite loss instead of
        # 0 * -inf = NaN.
        log_p = torch.clamp(torch.log(S), min=-100.)
        log_not_p = torch.clamp(torch.log(1. - S), min=-100.)

        loss = -torch.sum(Z * (S_true * log_p + (1. - S_true) * log_not_p))

        return loss

In [7]:
class HungarianModel(nn.Module):
    """
    Graph-matching network with four stages:

        1. CIE embedding of both graphs (shared weights).
        2. Affinity + Sinkhorn, giving a first soft matching.
        3. Cross-graph feature merging through that matching.
        4. A second CIE embedding, then affinity + Sinkhorn for the final matching.

    Args:
        d_in, m_in:   node feature size and number of edge channels of the inputs.
        d_out, m_out: node feature size and number of edge channels after each CIE layer.
    """
    def __init__(self, d_in, m_in, d_out, m_out):
        super().__init__()
        # Keep this construction order: it fixes how the random initialisers
        # consume the RNG (and the order of entries in state_dict).
        self.cie1 = CIE(d_in, m_in, d_out, m_out)
        self.cie2 = CIE(d_out, m_out, d_out, m_out)
        self.sim = SimilarityMatrix(d_out)
        self.cross = CrossGraphMerging(d_out)
        self.sinkhorn = Sinkhorn(n_iter=10)

    def forward(self, H1, E1, H2, E2):
        """
        Args:
            H1, H2: node features of the two graphs, (bs, n_nodes, d_in) each.
            E1, E2: edge features of the two graphs, (bs, m_in, n_nodes, n_nodes) each.

        Returns:
            The matrix produced by `self.sinkhorn` from the final node embeddings.
            Rows index graph-1 nodes and columns index graph-2 nodes.
        """
        # Stage 1: intra-graph embedding with the same CIE weights for both graphs.
        emb1, edges1 = self.cie1(H1, E1)
        emb2, edges2 = self.cie1(H2, E2)

        # Stage 2: first soft matching from the stage-1 embeddings.
        first_match = self.sinkhorn(self.sim(emb1, emb2))

        # Stage 3: exchange node information across graphs via the soft matching.
        emb1, emb2 = self.cross(emb1, emb2, first_match)

        # Stage 4: refine embeddings; the refined edge features are not needed.
        emb1, _ = self.cie2(emb1, edges1)
        emb2, _ = self.cie2(emb2, edges2)
        return self.sinkhorn(self.sim(emb1, emb2))