In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class InteractingLayer(nn.Module):
    def __init__(self, embedding_size, head_num=2, use_res=True, scaling=False):
        super(InteractingLayer, self).__init__()
        self.att_embedding_size = embedding_size // head_num
        self.head_num = head_num
        self.use_res = use_res
        self.scaling = scaling

        self.W_Query = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
        self.W_Key = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
        self.W_Value = nn.Parameter(torch.Tensor(embedding_size, embedding_size))

        if self.use_res:
            self.W_Res = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
        for tensor in self.parameters():
            nn.init.normal_(tensor, mean=0.0, std=0.05)

    def forward(self, inputs):

        # inputs: [1024, 26, 4]
        #keys: [1024, 26, 4]
        querys = torch.tensordot(inputs, self.W_Query, dims=([-1], [0]))
        keys = torch.tensordot(inputs, self.W_Key, dims=([-1], [0]))
        values = torch.tensordot(inputs, self.W_Value, dims=([-1], [0]))

        # keys: [2, 1024, 26, 2]
        querys = torch.stack(torch.split(querys, self.att_embedding_size, dim=2))
        keys = torch.stack(torch.split(keys, self.att_embedding_size, dim=2))
        values = torch.stack(torch.split(values, self.att_embedding_size, dim=2))

        # inner_product: [2, 1024, 26, 26]
        inner_product = torch.einsum('bnik,bnjk->bnij', querys, keys)

        if self.scaling:
            inner_product /= self.att_embedding_size ** 0.5
        self.normalized_att_scores = F.softmax(inner_product, dim=-1)

        # [2, 1024, 26, 2]
        result = torch.matmul(self.normalized_att_scores, values)
        # [1, 1024, 26, 4]
        result = torch.cat(torch.split(result, 1, ), dim=-1)
        # [1024, 26, 4]
        result = torch.squeeze(result, dim=0)
        if self.use_res:
            result += torch.tensordot(inputs, self.W_Res, dims=([-1], [0]))
        result = F.relu(result)
        return result

In [2]:
class GraphAttentionLayer(nn.Module):

    def __init__(self, in_features, out_features, dropout=0.1, alpha=0.5, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # 可学习参数 W，用于线性变换输入特征
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # 可学习参数 a，用于计算注意力系数
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        # Leaky ReLU 激活函数，用于引入非线性
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        h = torch.mm(input, self.W)  # 线性变换，得到节点特征表示，形状为 [N, out_features]
        N = h.size()[0]

        # 构造注意力机制的输入，a_input 的形状为 [N, N, 2*out_features]
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)

        # 计算注意力系数，通过学习得到，形状为 [N, N]
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        # 使用零向量对不相邻节点的注意力系数进行屏蔽
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        # 对注意力系数进行 softmax 归一化，然后使用 dropout 进行正则化
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # 使用注意力系数对节点特征进行加权求和，得到更新后的节点表示 h_prime
        h_prime = torch.matmul(attention, h)

        # 如果设置了 concat 为 True，则在更新后的节点表示上使用激活函数 ELU，并返回
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

In [3]:
layer = GraphAttentionLayer(in_features=10, out_features=6)

In [6]:
# 5*5的邻接矩阵
adjacency_matrix = [
    [0, 1, 1, 0, 0],
    [1, 0, 0, 1, 0],
    [1, 0, 0, 1, 1],
    [0, 1, 1, 0, 1],
    [0, 0, 1, 1, 0]
]

# 转换为PyTorch张量
adj = torch.tensor(adjacency_matrix, dtype=torch.float32)
feat = torch.Tensor(5,10)

In [7]:
layer(feat,adj)

tensor([[-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00,  7.8242e+17,
         -1.0000e+00],
        [ 1.7038e-30,  5.0690e-30,  0.0000e+00,  0.0000e+00,  5.4240e-30,
          1.0032e-31],
        [ 6.1067e+19,  1.0260e+20,  1.0073e+20,  1.8604e+20, -1.0000e+00,
         -1.0000e+00],
        [ 6.1067e+19,  1.0260e+20,  1.0073e+20,  1.8604e+20, -1.0000e+00,
         -1.0000e+00],
        [ 2.0356e+19,  3.4199e+19,  3.3577e+19,  6.2012e+19, -1.0000e+00,
         -1.0000e+00]], grad_fn=<EluBackward0>)