In [11]:
import torch
pwd = '/home/zjy/project/MetaIM/data'
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=2)

In [12]:
from torch_geometric.datasets import Planetoid

cora_dataset = Planetoid(root=pwd+'/cora', name='cora')
data = cora_dataset[0]
edge_index = data.edge_index

In [13]:
import numpy as np
individual_infection_path = pwd+'/for_meta/cora_individual_infection_sir.npy'
seeds_infection_path = pwd+'/for_meta/cora_seed_infection_sir.npy'

individual_infection = np.load(individual_infection_path)
seeds_infection = np.load(seeds_infection_path)
individual_infection.shape,seeds_infection.shape

((2708, 2708), (500, 2, 2708))

In [14]:
import torch
from torch_geometric.utils import to_scipy_sparse_matrix
import scipy.sparse as sp

# 转换为 scipy 稀疏矩阵
adj = to_scipy_sparse_matrix(edge_index)


# def normalize_adj(mx):
#     """Row-normalize sparse matrix"""
#     rowsum = np.array(mx.sum(1))
#     r_inv_sqrt = np.power(rowsum, -0.5).flatten()
#     r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
#     r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
#     return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)



# adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
# adj = normalize_adj(adj + sp.eye(adj.shape[0]))
adj = torch.Tensor(adj.toarray()).to_sparse()
adj


tensor(indices=tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
                       [ 633, 1862, 2582,  ...,  598, 1473, 2706]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(2708, 2708), nnz=10556, layout=torch.sparse_coo)

In [15]:
from torch import nn 
class SpecialSpmmFunction(torch.autograd.Function):
    """Special function for only sparse region backpropataion layer."""
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)

In [16]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpGraphAttentionLayer(nn.Module):
    """
    Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)
                
        self.a = nn.Parameter(torch.zeros(size=(1, 2*out_features)))
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def forward(self, input, adj):
        dv = input.device

        N = input.size()[0]
        if adj.layout == torch.sparse_coo:
            edge = adj.indices()
        else:
            edge = adj.nonzero().t()

        assert not torch.isnan(input).any()

        h = torch.mm(input, self.W)
        # h: N x out
        assert not torch.isnan(h).any()

        # Self-attention on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge: 2*D x E

        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
        assert not torch.isnan(edge_e).any()
        # edge_e: E

        e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N,1), device=dv))
        # e_rowsum: N x 1

        edge_e = self.dropout(edge_e)
        # edge_e: E

        h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
        assert not torch.isnan(h_prime).any()
        # h_prime: N x out
        
        h_prime = h_prime.div(e_rowsum)
        # h_prime: N x out
        assert not torch.isnan(h_prime).any()

        if self.concat:
            # if this layer is not last layer,
            return F.elu(h_prime)
        else:
            # if this layer is last layer,
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

In [17]:
class SpGAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Sparse version of GAT."""
        super(SpGAT, self).__init__()
        self.dropout = dropout

        self.attentions = [SpGraphAttentionLayer(nfeat, 
                                                 nhid, 
                                                 dropout=dropout, 
                                                 alpha=alpha, 
                                                 concat=True) for _ in range(nheads)]
        
        self.attentions1 = [SpGraphAttentionLayer(nhid * nheads, 
                                                 nhid, 
                                                 dropout=dropout, 
                                                 alpha=alpha, 
                                                 concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
            
        for i, attention in enumerate(self.attentions1):
            self.add_module('attention1_{}'.format(i), attention)

        self.out_att = SpGraphAttentionLayer(nhid * nheads, 
                                             nclass, 
                                             dropout=dropout, 
                                             alpha=alpha, 
                                             concat=False)
        

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(torch.cat([att(x, adj) for att in self.attentions], dim=1))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return x

In [18]:
# from torch_geometric.nn import GATConv
# class GAT(torch.nn.Module):
#     def __init__(self, in_channels, hidden_channels, out_channels, num_heads):
#         super(GAT, self).__init__()
#         self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads)
#         self.conv2 = GATConv(hidden_channels * num_heads, out_channels, heads=1)

#     def forward(self, x, edge_index):
#         x = F.elu(self.conv1(x, edge_index))
#         x = F.dropout(x, p=0.5, training=self.training)
#         x = self.conv2(x, edge_index)
#         return x

In [19]:
feat_num = int(seeds_infection[0][0].sum())

In [20]:
from torch.utils.data import Dataset, DataLoader, random_split


class CustomDataset(Dataset):
    def __init__(self, individual_infection,seeds_infection, feat_num):
        self.individual_infection = individual_infection
        self.seeds_infection = seeds_infection
        self.feat_shape = (len(individual_infection), feat_num)

    def __len__(self):
        return len(self.seeds_infection)

    def __getitem__(self, idx):
        seeds= np.nonzero(self.seeds_infection[idx][0])[0]
        
        feature = torch.zeros(self.feat_shape[0],self.feat_shape[1])
        for i in range(len(seeds)):
            seed_i_infection = torch.tensor(self.individual_infection[seeds[i]])
            feature[:, i] = seed_i_infection
            
        label = self.seeds_infection[idx][1]
        
        return feature, label

dataset = CustomDataset(individual_infection,seeds_infection,feat_num)

In [21]:
# 定义划分比例
train_ratio = 0.8
test_ratio = 0.2

# 划分数据集
train_dataset, test_dataset = random_split(dataset, [int(len(dataset)*train_ratio), int(len(dataset)*test_ratio)])

train_batch_size = 64
test_batch_size = 4

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads)
        self.conv2 = GATConv(hidden_channels * num_heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        return F.elu(x)

# class SpGAT_pyg(nn.Module):
#     def __init__(self, in_channels, hidden_channels, out_channels, dropout, alpha, heads):
#         super(SpGAT_pyg, self).__init__()
#         self.dropout = dropout
#         self.attentions = nn.ModuleList()
#         self.attentions1 = nn.ModuleList()

#         for _ in range(heads):
#             self.attentions.append(GATConv(in_channels, hidden_channels, heads=1, dropout=dropout, concat=True))
        
#         for _ in range(heads):
#             self.attentions1.append(GATConv(hidden_channels * heads, hidden_channels, heads=1, dropout=dropout, concat=True))

#         self.out_att = GATConv(hidden_channels * heads, out_channels, heads=1, dropout=dropout, concat=False)

#     def forward(self, x, edge_index):
#         x = F.dropout(x, p=self.dropout, training=self.training)
#         x = torch.cat([att(x, edge_index) for att in self.attentions], dim=1)
#         x = F.elu(x)
#         x = F.dropout(x, p=self.dropout, training=self.training)
#         x = torch.cat([att(x, edge_index) for att in self.attentions1], dim=1)
#         x = F.elu(x)
#         x = F.dropout(x, p=self.dropout, training=self.training)
#         x = self.out_att(x, edge_index)
#         return x

In [29]:
# from data.model.gat import GAT, SpGAT
from torch.optim import Adam, SGD


# forward_model = SpGAT(nfeat=feat_num, 
#                 nhid=64, 
#                 nclass=1, 
#                 dropout=0.2, 
#                 nheads=1, 
#                 alpha=0.2)
forward_model = GAT(feat_num, 256, 1, 4)

optimizer = Adam([{'params': forward_model.parameters()}], 
                 lr=1e-3)

adj = adj.to(device)
forward_model = forward_model.to(device)
forward_model.train()

GAT(
  (conv1): GATConv(135, 256, heads=4)
  (conv2): GATConv(1024, 1, heads=1)
)

In [30]:
edge_index = edge_index.to(device)
top_num = 300

for epoch in range(1200):

    total_overall = 0
    forward_loss = 0

    mean_train_accuracy = 0
    for batch_idx, feature_label in enumerate(train_loader):        
        features = feature_label[0].to(device)
        labels = feature_label[1].to(device)
        optimizer.zero_grad()
        
        loss = 0
        train_accuracy = 0
        for i, x_i in enumerate(features):
            y_i = labels[i]
            y_hat = forward_model(x_i, edge_index)
            _, top_indices_true = torch.topk(y_i.clone(), top_num)
            
            _, top_indices_predict = torch.topk(y_hat.clone().squeeze(-1), top_num)
            
            # 将张量数组转换为Python列表
            list1 = top_indices_true.tolist()
            list_pre = top_indices_predict.tolist()

            # 使用集合操作找到交集
            intersection = list(set(list1) & set(list_pre))
            accuracy_i = len(intersection) / top_num       
            train_accuracy += accuracy_i 

            forward_loss = F.mse_loss(y_hat.squeeze(-1), y_i, reduction='sum')        
            loss += forward_loss    
        
        total_overall += loss.item()    
        train_accuracy /= len(features)
        mean_train_accuracy = train_accuracy
        loss.backward()
        optimizer.step()
        # for p in forward_model.parameters():
        #     p.data.clamp_(min=0)
        
        
    print("Epoch: {}".format(epoch+1), 
        "\tTotal: {:.4f}".format(total_overall / train_batch_size),
        "\tMean_train_accuracy: {:.4f}".format(mean_train_accuracy),
        )  
    
    mean_accuracy = 0
    mean_accuracy_sum = 0

    
    for batch_idx, feature_label in enumerate(test_loader):   
        features = feature_label[0].to(device)
        labels = feature_label[1].to(device)
        
        accuracy = 0
        accuracy_sum = 0
        
        for i, x_i in enumerate(features):
            y_i = labels[i]
            _, top_indices_true = torch.topk(y_i, top_num)
            
            y_hat = forward_model(x_i, edge_index)
            
            _, top_indices_predict = torch.topk(y_hat.squeeze(-1), top_num)
            
            sum_pre = torch.sum(x_i, dim=1, keepdim=True)
            _, top_indices_sum = torch.topk(sum_pre.squeeze(-1), top_num)
            
            # 将张量数组转换为Python列表
            list1 = top_indices_true.tolist()
            list_pre = top_indices_predict.tolist()
            
            list_sum = top_indices_sum.tolist()

            # 使用集合操作找到交集
            intersection = list(set(list1) & set(list_pre))
            
            intersection_sum = list(set(list1) & set(list_sum))
            
            accuracy_i = len(intersection) / top_num       
            accuracy += accuracy_i 
            accuracy_sum += len(intersection_sum) / top_num  
        accuracy /= test_batch_size
        accuracy_sum/= test_batch_size
        mean_accuracy = accuracy
        mean_accuracy_sum = accuracy_sum
        break
    
    print(
        "\tMean_test_accuracy: {:.4f}".format(mean_accuracy),
        "\tMean_test_accuracy_sum: {:.4f}".format(mean_accuracy_sum)
        )  

    

Epoch: 1 	Total: 945.0693 	Mean_train_accuracy: 0.1750
	Mean_test_accuracy: 0.2108 	Mean_test_accuracy_sum: 0.3917
Epoch: 2 	Total: 822.2002 	Mean_train_accuracy: 0.2817
	Mean_test_accuracy: 0.2850 	Mean_test_accuracy_sum: 0.3917
Epoch: 3 	Total: 798.1047 	Mean_train_accuracy: 0.2748
	Mean_test_accuracy: 0.2975 	Mean_test_accuracy_sum: 0.3917
Epoch: 4 	Total: 795.1353 	Mean_train_accuracy: 0.2569
	Mean_test_accuracy: 0.2908 	Mean_test_accuracy_sum: 0.3917
Epoch: 5 	Total: 786.5519 	Mean_train_accuracy: 0.2902
	Mean_test_accuracy: 0.3017 	Mean_test_accuracy_sum: 0.3917
Epoch: 6 	Total: 782.3633 	Mean_train_accuracy: 0.2715
	Mean_test_accuracy: 0.3000 	Mean_test_accuracy_sum: 0.3917
Epoch: 7 	Total: 781.0003 	Mean_train_accuracy: 0.2806
	Mean_test_accuracy: 0.3042 	Mean_test_accuracy_sum: 0.3917
Epoch: 8 	Total: 780.3975 	Mean_train_accuracy: 0.2958
	Mean_test_accuracy: 0.3083 	Mean_test_accuracy_sum: 0.3917
Epoch: 9 	Total: 778.4504 	Mean_train_accuracy: 0.2842
	Mean_test_accuracy: 0.31