In [30]:
import torch
from torch import nn
from hungarian import Hungarian

In [225]:
class CIE(nn.Module):
    """
    Applies a channel-independent update rule to node and edge features.
    """
    def __init__(self, d, m):
        super().__init__()
       
        self.W0 = nn.Parameter(torch.empty(size=(d, d)), requires_grad=True)
        self.W1 = nn.Parameter(torch.empty(size=(m, d)), requires_grad=True)
        self.W2 = nn.Parameter(torch.empty(size=(d, d)), requires_grad=True)
        self.W3 = nn.Parameter(torch.empty(size=(m, m)), requires_grad=True)
        self.W4 = nn.Parameter(torch.empty(size=(d, m)), requires_grad=True)
       
        nn.init.xavier_normal_(self.W0)
        nn.init.xavier_normal_(self.W1)
        nn.init.xavier_normal_(self.W2)
        nn.init.xavier_normal_(self.W3)
        nn.init.xavier_normal_(self.W4)
   
    def forward(self, H, E, verbose=False):
        '''
       H - bs * k * d
       A - bs * k * k
       E - bs * m * k * k
       '''
        batch_size, k, d = H.shape
        m = E.shape[1]

        W1E = E.reshape(batch_size, m, k*k).permute(0, 2, 1) @ self.W1  # tensor of shape (batch_size, k*k, d)

        W1E = W1E.permute(0, 2, 1).reshape(batch_size, d, k, k)  # tensor of shape (batch_size, d, k, k)

        W2H = H @ self.W2

        W0H = H @ self.W0

        W3E = (E.permute(0,2,3,1) @ self.W3).permute(0,3,1,2)

 
        W4H = H @ self.W4

 
        W1E_W2H =  W1E @ W2H.permute(0,2,1).view((batch_size,d,k,1))

        W1E_W2H = W1E_W2H.reshape((batch_size,d,k)).permute(0,2,1)

        
       
        H_out = torch.relu(W1E_W2H) + torch.relu(W0H)

       
        h = W4H.permute(0,2,1).reshape(batch_size, m, k, 1) - W4H.permute(0, 2, 1).reshape(batch_size, m, 1, k)

        hE = torch.abs(h) * E

 
        E_out = torch.relu(hE) + torch.relu(W3E)
        
        
        if verbose:
            print('E: ', E)
            print('H: ', H)
            print('W1E: ', W1E)
            print('W1E: ', W1E)
            print('W2H: ', W2H)
            print('W0H: ', W0H)
            print('W3E: ', W3E)
            print('W4H: ', W4H)
            print('W1E_W2H: ', W1E_W2H)
            print('W1E_W2H: ', W1E_W2H)
            print('H_out: ', H_out)
            print('h: ', h)
            print('hE: ', hE)
            print('E_out: ', E_out)
 
        return H_out, E_out

In [232]:
import math

In [239]:
class SimilarityMatrix(nn.Module):
    """
    Computes a similarity matrix between two graphs based on the node features.
    """
    def __init__(self, d, tau=0.001):
        super().__init__()
        
        arr = torch.empty((d, ))
        arr.uniform_()
        arr = arr / math.sqrt(d)
        self.L = nn.Parameter(torch.diag(arr), requires_grad=True)
        self.tau = tau
    
    def forward(self, H1, H2):
        
        return torch.exp(self.tau * H1 @ self.L @ H2.transpose(1, 2))

In [33]:
class Sinkhorn(nn.Module):
    """
    Brings a matrix to a doubly-stochastic form via Sinkhorn algorithm.
    """
    
    def __init__(self, n_iter=20):
        super().__init__()
        self.n_iter = n_iter
        
    def forward(self, M):
        bs, k, _ = M.shape
        one = torch.ones((k,1))
        
        for i in range(self.n_iter):
            norm_row = M @ one @ one.T
            M = M / norm_row

            M = M / (one @ one.T @ M)
        
        return M

In [231]:
class CrossGraphMerging(nn.Module):
    """
    Performs feature merging between two graphs.
    
    input: H1, H2 - node features bs * k * d
           S - similarity matrix of bs * k * k
           
    output: H1_out, H2_out - merged node features 
    """
    
    def __init__(self, d):
        super().__init__()
        
        self.W = nn.Parameter(torch.empty(size=(2*d, d)), requires_grad=True)
        nn.init.xavier_normal_(self.W)
    
    def forward(self, H1, H2, S):
        H1_out = torch.cat([H1, S@H2], dim=-1) @ self.W
        H2_out = torch.cat([H2, S.transpose(-1, -2)@H1], dim=-1) @ self.W
        
        return H1_out, H2_out

In [35]:
def hung_attention(S, S_true):
    '''
    input: S - matching matrix bs * k * k
           S_true - true matching matrix bs * k * k
           
    output: Z - hungarian attention mask bs*k*k which should be elementwise
                multiplied by needed loss function
    '''
    
    bs = S.shape[0]
    Z = torch.zeros_like(S, requires_grad=False)
    
    for i in range(bs):
        hungarian = Hungarian(S[i], is_profit_matrix=True)
        hungarian.calculate()

        idx = torch.tensor(hungarian.get_results())

        Z_buf = torch.zeros(S.shape, requires_grad=False)
        Z_buf[idx[:,0],idx[:,1]] = 1.

        Z_buf = Z_buf.long() | S_true.long()
        
        Z[i] = Z_buf.double()
    
    return Z.double()


class HungarianLoss(nn.Module):
    def __init__(self, hung_attention=False):
        super().__init__()
        self.hung_attention = hung_attention
        
    def forward(self, S, S_true):
        if self.hung_attention:
            Z = hung_attention(S, S_true)
        else:
            Z = torch.ones(S.shape)
        
        loss = -torch.sum(Z * (S_true * torch.log(S) + (1. - S_true) * torch.log(1. - S)))
        
        return loss

In [248]:
class HungarianModel(nn.Module):
    def __init__(self, d, m):
        super().__init__()
        
        self.cie1 = CIE(d, m)
        self.cie2 = CIE(d, m)
        self.sim = SimilarityMatrix(d)
        self.cross = CrossGraphMerging(d)
        self.sinkhorn = Sinkhorn(n_iter=10)
        
    def forward(self, H1, E1, H2, E2):
        
        H1, E1 = self.cie1(H1, E1)
        H2, E2 = self.cie1(H2, E2)
#         print(f'\n\nafter cie1 \n\nH1: {H1} \n\n\nE1: {E1}')
        S = self.sinkhorn(self.sim(H1, H2))
#         print(f'\n\nafter sink1 \n\nS: {S}')
        H1, H2 = self.cross(H1, H2, S)
#         print(f'\n\nafter cross \n\nH1: {H1}')
        H1, E1 = self.cie2(H1, E1)
#         print(f'\n\nafter cie2 \n\nH1: {H1} \n\n\nE1: {E1}')
        H2, E2 = self.cie2(H2, E2)
#         print(f'\n\sim \n\nS: {self.sim(H1, H2)}')
        S = self.sinkhorn(self.sim(H1, H2))
#         print(f'\n\nafter sink2 \n\nS: {S}')
        
        return S

In [210]:
H1 = torch.empty((50, 10))
H1.normal_()
noise = torch.empty_like(H1)
noise.uniform_()
H2 = H1.clone() + 0.1 * noise

In [211]:
E1 = torch.empty((10, 50, 50))
E1.normal_()
noise = torch.empty_like(E1)
noise.uniform_()
E2 = E1.clone() + 0.1 * noise

In [212]:
H1 = H1.unsqueeze(0)
H2 = H2.unsqueeze(0)
E1 = E1.unsqueeze(0)
E2 = E2.unsqueeze(0)

In [223]:
print(f'H1: {H1} \n\n\nH2: {H2}')

H1: tensor([[[ 1.8214e+00,  4.9190e-01,  1.8101e+00, -4.5739e-01,  1.2204e+00,
           6.9183e-01, -6.3177e-01,  1.8867e+00, -1.1944e+00, -1.0547e+00],
         [-6.2148e-01, -2.1203e-02, -9.0768e-01,  1.2021e+00, -1.2980e+00,
          -2.6775e-01,  1.8330e+00, -5.6803e-01,  8.3794e-01,  4.4868e-01],
         [ 8.3348e-01, -3.6954e-02,  2.1442e-01,  6.6383e-01, -5.9640e-01,
           5.9061e-01,  2.2852e+00, -5.1136e-01, -3.0587e-01,  1.2544e+00],
         [ 1.0891e+00,  2.3967e-01, -8.2227e-01, -6.2664e-01,  6.2939e-01,
           1.2492e+00, -1.3993e-01,  1.9041e-01, -4.6548e-01, -5.3445e-01],
         [-7.2180e-01, -8.5613e-01, -5.9874e-01,  2.5292e-01, -2.3819e-01,
          -1.5233e+00,  9.3607e-01,  8.3893e-01, -1.8751e+00, -1.1651e+00],
         [ 5.4275e-01, -1.9178e+00, -2.7277e-02,  2.8580e-03, -1.2465e+00,
           1.8537e+00,  1.1574e+00,  2.5557e-01,  5.0530e-01, -5.3731e-01],
         [ 1.2561e+00, -1.6765e+00, -3.3301e-01,  1.0789e+00,  3.1896e-01,
           1.33

In [224]:
print(f'E1: {E1} \n\n\nE2: {E2}')

E1: tensor([[[[ 0.1064,  0.2818,  0.5460,  ...,  0.8734,  0.6053,  0.8511],
          [ 0.1636,  0.9537, -0.1779,  ...,  0.6933,  1.4608,  0.6824],
          [-0.7211, -0.7937, -1.4147,  ..., -0.0622,  0.9699, -0.2518],
          ...,
          [ 0.2215, -0.8022,  0.1912,  ..., -1.2008, -0.7705,  0.4311],
          [-0.0521,  1.1293, -2.5418,  ..., -0.3503, -0.0194, -0.3378],
          [ 0.8829,  1.0577,  0.1615,  ..., -1.2016,  0.8788, -0.1574]],

         [[-0.1203, -1.3968,  0.4034,  ...,  0.0666, -0.3140,  0.6169],
          [-0.9357,  0.9352, -0.7266,  ...,  0.5250,  1.2686,  1.4364],
          [-1.3384, -0.0115, -1.0877,  ...,  0.1729, -0.2196, -0.9759],
          ...,
          [-0.5525,  0.9364, -1.0082,  ..., -0.7715, -0.7458, -0.3325],
          [ 0.6077, -0.8702, -0.2455,  ...,  1.0610, -0.4850, -0.4298],
          [-0.3415, -0.7961,  0.9863,  ...,  0.6579,  0.6719, -1.4154]],

         [[-1.3228, -1.6875, -0.6982,  ...,  0.9604,  1.3964, -0.6362],
          [-0.3153, -1.029

In [221]:
cie_test = CIE(10, 10)

In [216]:
cie_test(H2, E2)

(tensor([[[0.0000e+00, 9.1536e+00, 1.2537e+01, 2.7615e+00, 2.3277e+00,
           7.8461e-01, 1.4859e+00, 2.0721e+01, 4.3435e-01, 4.6687e-01],
          [1.3748e+01, 1.9160e+00, 0.0000e+00, 0.0000e+00, 1.7907e+00,
           7.4241e+00, 6.7256e-01, 5.7198e-01, 3.5195e+00, 0.0000e+00],
          [0.0000e+00, 3.1784e+00, 1.2582e+01, 7.0438e-01, 3.9357e+00,
           1.0989e+00, 3.9303e-01, 8.7299e+00, 1.0946e-01, 0.0000e+00],
          [3.1985e-01, 3.7842e-01, 3.6539e+00, 9.0588e-01, 2.9880e+00,
           7.4322e-01, 4.5032e+00, 0.0000e+00, 1.5270e+00, 0.0000e+00],
          [9.1023e-01, 6.1173e+00, 6.1411e+00, 1.5132e+00, 9.2061e-01,
           2.6740e+00, 0.0000e+00, 0.0000e+00, 8.4200e-01, 2.7144e+00],
          [2.7259e+00, 4.5206e+00, 0.0000e+00, 1.2678e-01, 6.6965e+00,
           3.3567e+00, 5.0386e+00, 1.2241e+01, 1.0337e+00, 0.0000e+00],
          [8.4417e-01, 2.8645e+00, 0.0000e+00, 1.2763e+00, 2.5647e+00,
           1.7843e+01, 6.7591e+00, 1.0579e-01, 1.0036e+00, 0.0000e+00],

In [222]:
cie_test(H1, E1, verbose=True)

E:  tensor([[[[ 0.1064,  0.2818,  0.5460,  ...,  0.8734,  0.6053,  0.8511],
          [ 0.1636,  0.9537, -0.1779,  ...,  0.6933,  1.4608,  0.6824],
          [-0.7211, -0.7937, -1.4147,  ..., -0.0622,  0.9699, -0.2518],
          ...,
          [ 0.2215, -0.8022,  0.1912,  ..., -1.2008, -0.7705,  0.4311],
          [-0.0521,  1.1293, -2.5418,  ..., -0.3503, -0.0194, -0.3378],
          [ 0.8829,  1.0577,  0.1615,  ..., -1.2016,  0.8788, -0.1574]],

         [[-0.1203, -1.3968,  0.4034,  ...,  0.0666, -0.3140,  0.6169],
          [-0.9357,  0.9352, -0.7266,  ...,  0.5250,  1.2686,  1.4364],
          [-1.3384, -0.0115, -1.0877,  ...,  0.1729, -0.2196, -0.9759],
          ...,
          [-0.5525,  0.9364, -1.0082,  ..., -0.7715, -0.7458, -0.3325],
          [ 0.6077, -0.8702, -0.2455,  ...,  1.0610, -0.4850, -0.4298],
          [-0.3415, -0.7961,  0.9863,  ...,  0.6579,  0.6719, -1.4154]],

         [[-1.3228, -1.6875, -0.6982,  ...,  0.9604,  1.3964, -0.6362],
          [-0.3153, -1.029

(tensor([[[0.0000e+00, 5.9381e+00, 1.0947e+01, 2.4363e+00, 0.0000e+00,
           9.3657e-02, 3.9877e+00, 1.1727e+01, 3.1644e+00, 1.9336e+00],
          [2.7320e+00, 9.2836e+00, 1.1063e-01, 2.4617e-01, 0.0000e+00,
           5.5044e+00, 1.3759e+00, 0.0000e+00, 1.7231e+00, 8.8405e-01],
          [4.7749e-01, 3.1796e+00, 1.4592e+01, 3.8977e+00, 2.3378e-01,
           5.4246e+00, 4.2256e-01, 0.0000e+00, 3.6746e+00, 0.0000e+00],
          [4.4370e+00, 5.9165e-01, 4.0668e-01, 5.0192e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 1.0921e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 2.2520e+00, 6.1525e+00,
           0.0000e+00, 1.2066e-01, 0.0000e+00, 8.7882e+00, 0.0000e+00],
          [3.1123e+00, 0.0000e+00, 3.3889e+00, 4.6983e-01, 0.0000e+00,
           0.0000e+00, 9.1421e-01, 2.3593e+00, 7.3791e-01, 1.2709e+01],
          [1.4214e+01, 2.4259e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           3.1677e+00, 4.2160e+00, 2.3751e+00, 0.0000e+00, 0.0000e+00],

In [250]:
model = HungarianModel(10, 10)

In [251]:
S_pred = model(H1, E1, H2, E2)

In [252]:
S_pred

tensor([[[0.1104, 0.0269, 0.0284,  ..., 0.0501, 0.0121, 0.0074],
         [0.0257, 0.0257, 0.0225,  ..., 0.0230, 0.0195, 0.0184],
         [0.0254, 0.0235, 0.0305,  ..., 0.0318, 0.0166, 0.0152],
         ...,
         [0.0471, 0.0244, 0.0338,  ..., 0.0471, 0.0127, 0.0091],
         [0.0119, 0.0188, 0.0155,  ..., 0.0120, 0.0258, 0.0271],
         [0.0073, 0.0179, 0.0132,  ..., 0.0084, 0.0265, 0.0424]]],
       grad_fn=<DivBackward0>)

In [254]:
criterion = HungarianLoss(hung_attention=hung_attention)

In [257]:
criterion(S_pred.detach(), torch.ones_like(S_pred).detach())

IndexError: index 3 is out of bounds for dimension 0 with size 1