In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Seed the global RNG so the synthetic inputs below are identical on every
# Restart & Run All. Modules built in later cells also draw their initial
# weights from this seeded stream, provided the cells are run in order.
SEED = 0
torch.manual_seed(SEED)

# Feature size of each modality vector.
k = 120

# Three random stand-in modality vectors, each of shape (1, k).
V_1 = torch.randn(size=(1, k))
V_2 = torch.randn(size=(1, k))
V_3 = torch.randn(size=(1, k))

## First layer

In [3]:
class modal_attention_network(nn.Module):
    """Score an input with one linear layer followed by a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        dropout: Accepted as part of the signature, but this class never
            reads it and applies no dropout.
    """

    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(fc(x))``."""
        return self.sigmoid(self.fc(x))


class FirstLayer(nn.Module):
    """Unimodal fusion: scale each modality by a learned score and average.

    One ``modal_attention_network`` instance is shared by the three inputs
    and maps each of them to a single score. The fused vector is one third
    of the sum of the score-scaled inputs.

    Args:
        input_dim: Feature size of each modality vector.
        dropout: Forwarded unchanged to ``modal_attention_network``.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        self.modal_attention_network = modal_attention_network(input_dim, 1, dropout)

    def forward(self, x_1, x_2, x_3):
        """Return ``(U, a_1, a_2, a_3)``: the fused vector and the three scores."""
        scorer = self.modal_attention_network
        a_1, a_2, a_3 = (scorer(x) for x in (x_1, x_2, x_3))

        # Stack the scaled inputs on a new dim 1, then reduce over it.
        scaled = torch.stack((a_1 * x_1, a_2 * x_2, a_3 * x_3), dim=1)
        U = (1 / 3) * scaled.sum(dim=1)
        return U, a_1, a_2, a_3

In [4]:
# Build the first (unimodal) layer and run it on the three sample vectors.
first_layer = FirstLayer(input_dim=k, dropout=0.5)
U, a_1, a_2, a_3 = first_layer(x_1=V_1, x_2=V_2, x_3=V_3)
U.shape  # shape of the fused unimodal vector

torch.Size([1, 120])

In [5]:
class MultiLayerNeuralFusionNetwork(nn.Module):
    """Fuse two modality vectors into one bimodal node.

    The two inputs are concatenated along dim 1 and projected back to
    ``input_dim`` features by a single linear layer.

    Args:
        input_dim: Feature size of each input, and of the fused output.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fusion_layer = nn.Linear(2 * input_dim, input_dim)

    def forward(self, x_1, x_2):
        """Return the fused node for ``x_1`` and ``x_2``."""
        joined = torch.cat((x_1, x_2), dim=1)
        return self.fusion_layer(joined)

In [6]:
# Smoke test: fuse two of the sample vectors and inspect the output shape.
multilayer_neural_fusion_network = MultiLayerNeuralFusionNetwork(input_dim=k)
multilayer_neural_fusion_network(x_1=V_1, x_2=V_2).shape

torch.Size([1, 120])

In [7]:
class SecondLayer(nn.Module):
    """Bimodal fusion: build the three pairwise nodes and combine them.

    Every pair of inputs is fused by one shared
    ``MultiLayerNeuralFusionNetwork``. Each pair gets a weight derived from
    the two incoming weights and the dot-product similarity of the two
    vectors; the layer output is the weighted sum of the three fused nodes.

    Inputs are expected as 2-D tensors: vectors of shape (batch, k) and
    weights that broadcast against a (batch, 1) column.

    Args:
        k: Feature size of every input vector.
    """

    def __init__(self, k: int = 120):
        super().__init__()
        self.multilayer_neural_fusion_network = MultiLayerNeuralFusionNetwork(k)

    @staticmethod
    def _raw_weight(x_i, x_j, a_i, a_j):
        """Un-normalised pair weight ``(a_i + a_j) / (<x_i, x_j> + 0.5)``."""
        # Row-wise dot product kept as a (batch, 1) column. ``x_i @ x_j.T``
        # gives the same value for a single row, but a (batch, batch) matrix
        # for larger batches, which cannot broadcast against the fused nodes.
        similarity = (x_i * x_j).sum(dim=1, keepdim=True)
        # NOTE(review): the similarity is unbounded, so this denominator can
        # be near zero or negative - confirm the inputs are scaled as intended.
        return (a_i + a_j) / (similarity + 0.5)

    def forward(self, x_1, x_2, x_3, a_1, a_2, a_3):
        """Return ``(B, a_12, a_13, a_23, V_12, V_13, V_23)``.

        ``B`` is the weighted sum of the fused nodes ``V_ij``; ``a_ij`` are
        the weights that were applied to them.
        """
        fuse = self.multilayer_neural_fusion_network
        V_12 = fuse(x_1, x_2)
        V_13 = fuse(x_1, x_3)
        V_23 = fuse(x_2, x_3)

        a_12_hat = self._raw_weight(x_1, x_2, a_1, a_2)
        a_13_hat = self._raw_weight(x_1, x_3, a_1, a_3)
        a_23_hat = self._raw_weight(x_2, x_3, a_2, a_3)

        e_12 = torch.exp(a_12_hat)
        e_13 = torch.exp(a_13_hat)
        e_23 = torch.exp(a_23_hat)

        # NOTE(review): each denominator leaves out the pair's own term, so
        # the three weights do not sum to 1 - confirm this is intended rather
        # than a softmax over all three pairs.
        a_12 = e_12 / (e_13 + e_23)
        a_13 = e_13 / (e_12 + e_23)
        a_23 = e_23 / (e_12 + e_13)

        B = torch.sum(torch.stack([a_12 * V_12, a_13 * V_13, a_23 * V_23], dim=1), dim=1)
        return B, a_12, a_13, a_23, V_12, V_13, V_23

In [8]:
# Pass the notebook-level ``k`` explicitly: the constructor default (120) only
# happens to match it, so relying on the default breaks once ``k`` is changed.
second_layer = SecondLayer(k)
B, a_12, a_13, a_23, V_12, V_13, V_23 = second_layer(V_1, V_2, V_3, a_1, a_2, a_3)

In [9]:
# Shape of the fused bimodal vector returned by the second layer.
B.shape

torch.Size([1, 120])

In [10]:
class fusion_layer_for_thirdmodal(nn.Module):
    """Fuse each unimodal vector with the bimodal node of the other two.

    Produces the three nodes ``V_1_23``, ``V_2_13`` and ``V_3_12`` with one
    shared ``MultiLayerNeuralFusionNetwork``, plus a weight for each node
    derived from the incoming weights and the dot-product similarity of the
    two fused vectors.

    Inputs are expected as 2-D tensors: vectors of shape (batch, k) and
    weights that broadcast against a (batch, 1) column.

    Args:
        k: Feature size of every input vector.
    """

    def __init__(self, k: int = 120):
        super().__init__()
        self.multilayer_neural_fusion_network = MultiLayerNeuralFusionNetwork(k)

    @staticmethod
    def _raw_weight(x_i, x_j, a_i, a_j):
        """Un-normalised pair weight ``(a_i + a_j) / (<x_i, x_j> + 0.5)``."""
        # Row-wise dot product kept as a (batch, 1) column. ``x_i @ x_j.T``
        # gives the same value for a single row, but a (batch, batch) matrix
        # for larger batches, which cannot broadcast against the fused nodes.
        similarity = (x_i * x_j).sum(dim=1, keepdim=True)
        # NOTE(review): the similarity is unbounded, so this denominator can
        # be near zero or negative - confirm the inputs are scaled as intended.
        return (a_i + a_j) / (similarity + 0.5)

    def forward(self, V_1, V_23, V_2, V_13, V_3, V_12, a_1, a_23, a_2, a_13, a_3, a_12):
        """Return ``(V_1_23, V_2_13, V_3_12, a_1_23, a_2_13, a_3_12)``.

        Arguments come in (unimodal, complementary bimodal) pairs: first the
        six vectors, then the six matching weights in the same order.
        """
        fuse = self.multilayer_neural_fusion_network
        V_1_23 = fuse(V_1, V_23)
        V_2_13 = fuse(V_2, V_13)
        V_3_12 = fuse(V_3, V_12)

        a_1_23_hat = self._raw_weight(V_1, V_23, a_1, a_23)
        a_2_13_hat = self._raw_weight(V_2, V_13, a_2, a_13)
        a_3_12_hat = self._raw_weight(V_3, V_12, a_3, a_12)

        e_1_23 = torch.exp(a_1_23_hat)
        e_2_13 = torch.exp(a_2_13_hat)
        e_3_12 = torch.exp(a_3_12_hat)

        # NOTE(review): each denominator leaves out the node's own term, so
        # the three weights do not sum to 1 - confirm this is intended rather
        # than a softmax over all three nodes.
        a_1_23 = e_1_23 / (e_2_13 + e_3_12)
        a_2_13 = e_2_13 / (e_1_23 + e_3_12)
        a_3_12 = e_3_12 / (e_1_23 + e_2_13)

        return V_1_23, V_2_13, V_3_12, a_1_23, a_2_13, a_3_12



In [11]:
class ThirdLayer(nn.Module):
    """Trimodal fusion: weighted sum of six third-level nodes.

    Three nodes come from pairing the bimodal nodes with each other
    (``SecondLayer``), and three from pairing each unimodal vector with the
    bimodal node of the other two modalities
    (``fusion_layer_for_thirdmodal``).

    Args:
        k: Forwarded to both sub-modules.
    """

    def __init__(self, k: int = 120):
        super().__init__()
        self.fusion_module_1 = SecondLayer(k)
        self.fusion_module_2 = fusion_layer_for_thirdmodal(k)

    def forward(self, V_1, V_2, V_3, V_12, V_13, V_23, a_1, a_2, a_3, a_12, a_13, a_23):
        """Return ``O``, the sum of the six weight-scaled nodes."""
        # Bimodal x bimodal. The sub-module's own fused output (first item)
        # is discarded; only its weights and nodes are used.
        _, w_1213, w_1223, w_1323, n_1213, n_1223, n_1323 = self.fusion_module_1(
            V_12, V_13, V_23, a_12, a_13, a_23
        )

        # Unimodal x complementary bimodal.
        n_1_23, n_2_13, n_3_12, w_1_23, w_2_13, w_3_12 = self.fusion_module_2(
            V_1, V_23, V_2, V_13, V_3, V_12, a_1, a_23, a_2, a_13, a_3, a_12
        )

        pairs = (
            (w_1_23, n_1_23),
            (w_2_13, n_2_13),
            (w_3_12, n_3_12),
            (w_1213, n_1213),
            (w_1223, n_1223),
            (w_1323, n_1323),
        )
        scaled = [weight * node for weight, node in pairs]
        O = torch.stack(scaled, dim=1).sum(dim=1)
        return O

In [12]:
# Pass the notebook-level ``k`` explicitly: the constructor default (120) only
# happens to match it, so relying on the default breaks once ``k`` is changed.
third_layer = ThirdLayer(k)
O = third_layer(
    V_1, V_2, V_3,        # unimodal vectors
    V_12, V_13, V_23,     # bimodal nodes from the second layer
    a_1, a_2, a_3,        # unimodal weights from the first layer
    a_12, a_13, a_23,     # bimodal weights from the second layer
)
O.shape

torch.Size([1, 120])

In [13]:
# Show the values of the third layer's output tensor.
O

tensor([[ 0.6015,  0.0893,  0.2374, -0.3873,  0.3150,  0.0814,  1.3210, -0.6099,
         -1.3153, -0.3951,  0.2140,  0.0826,  0.1556,  0.9649,  0.1026,  1.3271,
         -0.0451, -0.4745,  0.0339, -0.7251,  0.0915, -0.5068,  0.5691,  0.5645,
         -0.9825, -0.1342, -0.2583, -0.4383, -1.0022,  0.4898, -0.1810, -0.1723,
          0.3781, -0.6662, -1.2945, -1.5308, -0.1813, -0.2842, -0.9219,  0.3368,
         -1.3098, -0.4573, -0.8250,  1.3964, -1.1801, -0.2033,  0.9082, -0.1288,
          0.1234,  1.0271, -0.5290, -0.8767,  0.4783,  0.8033, -0.4152, -0.5281,
         -0.7501, -0.3347,  0.4673, -0.4432, -0.4431, -0.0732,  0.4266, -0.1887,
          0.1840,  1.1373,  0.6127, -0.0106,  0.7999, -0.5460, -0.0143,  0.2430,
         -0.2398, -0.2162,  0.1107,  0.8499, -0.2635,  0.5386,  0.3964,  0.8227,
          2.0436,  0.9463, -0.9779, -0.1312,  0.3160,  0.0258,  2.0368, -0.8336,
          0.2256,  1.0543,  0.2935, -0.0204, -0.4197,  0.3797,  0.0190, -1.2091,
         -0.7152,  0.7517, -