In [1]:
import math,os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree
from einops import rearrange, repeat

In [73]:
def full_attention_conv(qs, ks, vs, kernel, output_attn=False):
    '''
    qs: query tensor [N, H, M]
    ks: key tensor [L, H, M]
    vs: value tensor [L, H, D]

    return output [N, H, D]
    '''
    if kernel == 'simple':
        # normalize input
        qs = qs / torch.norm(qs, p=2) # [N, H, M]
        ks = ks / torch.norm(ks, p=2) # [L, H, M]
        N = qs.shape[0]

        # numerator
        kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
        attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs) # [N, H, D]
        all_ones = torch.ones([vs.shape[0]]).to(vs.device)
        vs_sum = torch.einsum("l,lhd->hd", all_ones, vs) # [H, D]
        attention_num += vs_sum.unsqueeze(0).repeat(vs.shape[0], 1, 1) # [N, H, D]

        # denominator
        all_ones = torch.ones([ks.shape[0]]).to(ks.device)
        ks_sum = torch.einsum("lhm,l->hm", ks, all_ones)
        attention_normalizer = torch.einsum("nhm,hm->nh", qs, ks_sum)  # [N, H]

        # attentive aggregated results
        attention_normalizer = torch.unsqueeze(attention_normalizer, len(attention_normalizer.shape))  # [N, H, 1]
        attention_normalizer += torch.ones_like(attention_normalizer) * N
        attn_output = attention_num / attention_normalizer # [N, H, D]

        # compute attention for visualization if needed
        if output_attn:
            attention = torch.einsum("nhm,lhm->nlh", qs, ks) / attention_normalizer # [N, L, H]

    elif kernel == 'sigmoid':
        # numerator
        attention_num = torch.sigmoid(torch.einsum("nhm,lhm->nlh", qs, ks))  # [N, L, H]

        # denominator
        all_ones = torch.ones([ks.shape[0]]).to(ks.device)
        attention_normalizer = torch.einsum("nlh,l->nh", attention_num, all_ones)
        attention_normalizer = attention_normalizer.unsqueeze(1).repeat(1, ks.shape[0], 1)  # [N, L, H]

        # compute attention and attentive aggregated results
        attention = attention_num / attention_normalizer
        attn_output = torch.einsum("nlh,lhd->nhd", attention, vs)  # [N, H, D]

    if output_attn:
        return attn_output, attention
    else:
        return attn_output

def gcn_conv(x, edge_index, edge_weight):
    N, H = x.shape[0], x.shape[1]
    # print(N)
    print("edge_index",edge_index.shape)
    row, col = edge_index
    # print(row)
    d = degree(col, N).float()
    d_norm_in = (1. / d[col]).sqrt()
    d_norm_out = (1. / d[row]).sqrt()
    
    gcn_conv_output = []
    
    if edge_weight is None:
        value = torch.ones_like(row) * d_norm_in * d_norm_out
        
    else:
        value = edge_weight * d_norm_in * d_norm_out
    value = torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
   
    adj = SparseTensor(row=col, col=row, value=value, sparse_sizes=(N, N))
    # print(adj)
    # adj.to('cpu')
    for i in range(x.shape[1]):
        gcn_conv_output.append(matmul(adj, x[:, i]) )  # [N, D]
        # print("here1")
    gcn_conv_output = torch.stack(gcn_conv_output, dim=1) # [N, H, D]
    
    return gcn_conv_output

class DIFFormerConv(nn.Module):
    '''
    one DIFFormer layer
    '''
    def __init__(self, in_channels,
               out_channels,
               num_heads,
               kernel='simple',
               use_graph=True,
               use_weight=True):
        super(DIFFormerConv, self).__init__()
        self.Wk = nn.Linear(in_channels, out_channels * num_heads)
        self.Wq = nn.Linear(in_channels, out_channels * num_heads)
        if use_weight:
            self.Wv = nn.Linear(in_channels, out_channels * num_heads)

        self.out_channels = out_channels
        self.num_heads = num_heads
        self.kernel = kernel
        self.use_graph = use_graph
        self.use_weight = use_weight

    def reset_parameters(self):
        self.Wk.reset_parameters()
        self.Wq.reset_parameters()
        if self.use_weight:
            self.Wv.reset_parameters()

    def forward(self, query_input, source_input, edge_index=None, edge_weight=None, output_attn=False):
        # feature transformation
        query = self.Wq(query_input).reshape(-1, self.num_heads, self.out_channels)
        key = self.Wk(source_input).reshape(-1, self.num_heads, self.out_channels)
        if self.use_weight:
            value = self.Wv(source_input).reshape(-1, self.num_heads, self.out_channels)
        else:
            value = source_input.reshape(-1, 1, self.out_channels)

        # compute full attentive aggregation
        if output_attn:
            attention_output, attn = full_attention_conv(query, key, value, self.kernel, output_attn)  # [N, H, D]
        else:
            attention_output = full_attention_conv(query,key,value,self.kernel) # [N, H, D]

        # use input graph for gcn conv
        if self.use_graph:
            final_output = attention_output + gcn_conv(value, edge_index, edge_weight)
        else:
            final_output = attention_output
        final_output = final_output.mean(dim=1)

        if output_attn:
            return final_output, attn
        else:
            return final_output

class DIFFormer(nn.Module):
    '''
    DIFFormer model class
    x: input node features [N, D]
    edge_index: 2-dim indices of edges [2, E]
    return y_hat predicted logits [N, C]
    '''
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, num_heads=1, kernel='simple',
                 alpha=0.5, dropout=0.5, use_bn=True, use_residual=True, use_weight=True, use_graph=True):
        
        super(DIFFormer, self).__init__()
        
        # print("in:",in_channels)
        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(in_channels, hidden_channels))
        self.bns = nn.ModuleList()
        self.bns.append(nn.LayerNorm(hidden_channels))
        for i in range(num_layers):
            self.convs.append(
                DIFFormerConv(hidden_channels, hidden_channels, num_heads=num_heads, kernel=kernel, use_graph=use_graph, use_weight=use_weight))
            self.bns.append(nn.LayerNorm(hidden_channels))

        self.fcs.append(nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn
        self.residual = use_residual
        self.alpha = alpha

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        layer_ = []
        # print(x.shape)
        # print(edge_index)
        # input MLP layer
        x = self.fcs[0](x)
        # print(x.shape)
        # print(x)
        if self.use_bn:
            x = self.bns[0](x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # store as residual link
        layer_.append(x)

        for i, conv in enumerate(self.convs):
            # graph convolution with DIFFormer layer
            x = conv(x, x, edge_index, edge_weight)
            if self.residual:
                x = self.alpha * x + (1-self.alpha) * layer_[i]
            if self.use_bn:
                x = self.bns[i+1](x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            layer_.append(x)

        # output MLP layer
        x_out = self.fcs[-1](x)
        return x_out

    def get_attentions(self, x):
        layer_, attentions = [], []
        x = self.fcs[0](x)
        if self.use_bn:
            x = self.bns[0](x)
        x = self.activation(x)
        layer_.append(x)
        for i, conv in enumerate(self.convs):
            x, attn = conv(x, x, output_attn=True)
            attentions.append(attn)
            if self.residual:
                x = self.alpha * x + (1 - self.alpha) * layer_[i]
            if self.use_bn:
                x = self.bns[i + 1](x)
            layer_.append(x)
        return torch.stack(attentions, dim=0) # [layer num, N, N]


In [4]:
# def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, num_heads=1, kernel='simple',
#                  alpha=0.5, dropout=0.5, use_bn=True, use_residual=True, use_weight=True, use_graph=True):
device = torch.device("cuda:" + str(0)) if torch.cuda.is_available() else torch.device("cpu")
size = (128,32)
data = torch.randn(size).to(device)
print("数据集大小",data.shape)
# print(data)
size2 = (64,6)
data2 = torch.randn(size2).to(device)
print("图特征大小",data2.shape)
model=DIFFormer(data.shape[1],hidden_channels=512, out_channels=128, num_layers=2, alpha=0.5, dropout=0.5, num_heads=8, kernel="simple",
                       use_bn=True, use_residual=True, use_graph=False, use_weight=True).to(device)
out = model(data,None, None)
print(out.shape)

数据集大小 torch.Size([128, 32])
图特征大小 torch.Size([64, 6])
in: 32
torch.Size([128, 32])
None
torch.Size([128, 512])
tensor([[ 0.5305,  0.2127,  0.0791,  ..., -0.3409,  0.4752, -0.6582],
        [ 0.7388, -0.1885, -0.6768,  ...,  1.1323,  0.2417,  1.2532],
        [-0.1485,  0.3718, -0.4346,  ...,  0.0753, -0.1904,  0.7740],
        ...,
        [-0.2066, -0.1274, -1.4478,  ...,  0.5448, -0.6214,  0.5695],
        [ 0.9115,  0.4818, -1.0217,  ...,  1.4887, -1.1908,  1.4185],
        [-0.0199, -0.7302, -0.1347,  ...,  0.1038,  0.0548, -0.0278]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
torch.Size([128, 128])


In [76]:
class myformer(nn.Module):
    def __init__(self,
                 node_in_channels,
                 graph_in_channels,
                 dif_hidden_channels,
                 hidden_channels,
                 dif_out_channels,
                 out_channels,
                 num_layers=2,
                 num_heads=8,
                 alpha=0.5,
                 dropout=0.5,
                 use_bn=True,
                 use_residual=True,
                 use_graph=False,
                 use_weight=True):
        super(myformer,self).__init__()
        # print("in_channels=",in_channels)
        # print("out_channels=",out_channels)
        self.difformer = DIFFormer(node_in_channels,dif_hidden_channels, dif_out_channels, num_layers, alpha=alpha, dropout=dropout, num_heads=num_heads, kernel="simple",
                       use_bn=use_bn, use_residual=use_residual, use_graph=use_graph, use_weight=use_weight).to(device)
        
        self.fcs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.fcs.append(nn.Linear(dif_out_channels, hidden_channels))
        self.bns.append(nn.LayerNorm(hidden_channels))
        self.fcs.append(nn.Linear(graph_in_channels, hidden_channels))
        self.bns.append(nn.LayerNorm(hidden_channels))
        self.fcs.append(nn.Linear(hidden_channels, out_channels))
        self.bns.append(nn.LayerNorm(out_channels))
        
        self.cls_token = nn.Parameter(torch.randn(1, 128))					# nn.Parameter()定义可学习参数
        
        self.difformer2 = DIFFormer(dif_out_channels,dif_hidden_channels, dif_out_channels, 1, alpha=alpha, dropout=dropout, num_heads=num_heads, kernel="simple",
                       use_bn=use_bn, use_residual=use_residual, use_graph=True, use_weight=use_weight).to(device)
        
        
        self.activation = F.relu
        self.use_bn = use_bn
    def forward(self, x, now):
        n = x.size(0)
        print("num_node",n)
        print("input node shape:",x.shape)
        print("input graph shape:",now.shape)
        x = self.difformer(x,None,None)
        print("after difformer node shape:",x.shape)
        
        x = torch.cat((self.cls_token, x), dim=0)
        print("拼接后",x.shape)
        # 假设拼接的节点与其他所有节点相连
        edge_list = torch.zeros((2, n), dtype=torch.int64)
        edge_list[0] = torch.zeros(n, dtype=torch.int64)
        edge_list[1] = torch.arange(1, n+1, dtype=torch.int64)
        # row, col = edge_list
        # print(row)
        # print("edge_list",edge_list)
        
        x = self.difformer2(x,edge_list,None)
        # 取第一个向量表示整个数据集
        x = x[0]
        
        x = self.fcs[0](x)
        if self.use_bn:
            x = self.bns[0](x)
        now = self.fcs[1](now)
        if self.use_bn:
            now = self.bns[1](now)
            
        # 特征融合
        repeat_x = x.repeat(now.size(0), 1, 1)
        additional_now = now.unsqueeze(1)
        y = repeat_x + additional_now
        print("after merge:",y.shape)
        
        
        
        
        # y = y.mean(dim=2)
        # sum_tensor = torch.sum(y, dim=1)  # 在第二个维度上求和，结果大小为 (b, d)
        # mean_tensor = sum_tensor / x.size(1) 
        # print("after mean:",mean_tensor.shape)
        # result = self.fcs[-1](mean_tensor)
        # if self.use_bn:
        #     result = self.bns[-1](result)
        return y
    
    def reset_parameters(self):
        # for conv in self.convs:
        #     conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()
        self.difformer.reset_parameters()
        

In [74]:
# def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, num_heads=1, kernel='simple',
#                  alpha=0.5, dropout=0.5, use_bn=True, use_residual=True, use_weight=True, use_graph=True):
device = torch.device("cuda:" + str(0)) if torch.cuda.is_available() else torch.device("cpu")
size = (128,32)
data = torch.randn(size).to(device)
# print("数据集大小",data.shape)
# print(data)
size2 = (64,6)
data2 = torch.randn(size2).to(device)
# print("图特征大小",data2.shape)
model=myformer(node_in_channels=data.size(1),
               graph_in_channels=data2.size(1),
               dif_hidden_channels=512,
               hidden_channels=512,
               dif_out_channels=128,
               out_channels=1,
               num_layers=2,
               num_heads=8,
               alpha=0.5,
               dropout=0.5,
               use_bn=True,
               use_residual=True,
               use_graph=False,
               use_weight=True).to(device)
n = data.size(0)
# edge_list = torch.zeros((2, n), dtype=torch.int64)
# edge_list[0] = torch.zeros(n, dtype=torch.int64)
# edge_list[1] = torch.arange(1, n+1, dtype=torch.int64)
# # print("edge_list",edge_list)
# edge_list.to(device)
out = model(data,data2)
print(out.shape)

num_node 128
input node shape: torch.Size([128, 32])
input graph shape: torch.Size([64, 6])
after difformer node shape: torch.Size([128, 128])
拼接后 torch.Size([129, 128])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
129
edge_index torch.Size([2, 128])


RuntimeError: mat.device().is_cpu() INTERNAL ASSERT FAILED at "csrc/cpu/spmm_cpu.cpp":16, please report a bug to PyTorch. mat must be CPU tensor

In [28]:
import torch

# 假设有 2 组数据，每组数据有 3 个数据，每个数据有 4 个特征
b = 2
n = 3
d = 4

# 创建一个随机的 b*n*d 大小的张量作为示例数据
tensor_array = torch.randn(b, n, d)
print(tensor_array)
# 在 n 维度上求和取平均
sum_tensor = torch.sum(tensor_array, dim=1)  # 在第二个维度上求和，结果大小为 (b, d)
mean_tensor = sum_tensor / n  # 取平均，注意这里的 n 是每组数据的数量

print("Mean tensor shape:", mean_tensor.shape)
print("Mean tensor values:", mean_tensor)


tensor([[[ 1.8335, -0.0871, -1.3135,  1.1532],
         [ 0.5646, -0.2365, -0.7816,  0.0688],
         [-0.1525,  2.2362, -0.2618, -0.9286]],

        [[-1.9100, -0.0028, -0.0393,  0.2988],
         [-0.1047, -0.4419, -1.7085, -0.1470],
         [ 0.6696,  0.5725, -1.0773,  1.0037]]])
Mean tensor shape: torch.Size([2, 4])
Mean tensor values: tensor([[ 0.7485,  0.6375, -0.7856,  0.0978],
        [-0.4484,  0.0426, -0.9417,  0.3852]])
