In [1]:
import torch
import sys
sys.path.append('/home/zjy/project/MetaIM')
pwd = '/home/zjy/project/MetaIM/data'
device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=7)

In [2]:
from torch_geometric.datasets import Planetoid

cora_dataset = Planetoid(root=pwd+'/cora', name='cora')
data = cora_dataset[0]
edge_index = data.edge_index

In [3]:
import numpy as np
individual_infection_path = pwd+'/for_meta/cora_individual_infection_sir_200.npy'
seeds_infection_path = pwd+'/for_meta/cora_seed_infection_sir_200_sample_1000.npy'

individual_infection = np.load(individual_infection_path)
seeds_infection = np.load(seeds_infection_path)
individual_infection.shape,seeds_infection.shape

((2708, 2708), (1000, 2, 2708))

In [4]:
import torch
from torch_geometric.utils import to_scipy_sparse_matrix
import scipy.sparse as sp

# 转换为 scipy 稀疏矩阵
adj = to_scipy_sparse_matrix(edge_index)

adj = torch.Tensor(adj.toarray()).to_sparse()
adj = adj.to(device)


In [5]:
seed_num = int(seeds_infection[0][0].sum())

In [6]:
from torch.utils.data import Dataset, DataLoader, random_split


class CustomDataset(Dataset):
    def __init__(self, individual_infection,seeds_infection):
        self.individual_infection = individual_infection
        self.seeds_infection = seeds_infection

    def __len__(self):
        return len(self.seeds_infection)

    def __getitem__(self, idx):
        
        return self.seeds_infection[idx][0], self.seeds_infection[idx][1]

dataset = CustomDataset(individual_infection, seeds_infection)

In [7]:
# 定义划分比例
train_ratio = 0.8
test_ratio = 0.2

# 划分数据集
train_dataset, test_dataset = random_split(dataset, [int(len(dataset)*train_ratio), int(len(dataset)*test_ratio)])

train_batch_size = 32
test_batch_size = 2

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

训练VAE

In [8]:
# from data import model 
from data.model.model import VAEModel, Encoder, Decoder
from torch.optim import Adam, SGD
import torch.nn.functional as F

# # hidden_dim = 256
# # latent_dim = 64
hidden_dim = 1024
latent_dim = 128

encoder = Encoder(input_dim= len(seeds_infection[0][0]), 
                  hidden_dim=hidden_dim, 
                  latent_dim=latent_dim)

decoder = Decoder(input_dim=latent_dim, 
                  latent_dim=latent_dim, 
                  hidden_dim=hidden_dim, 
                  output_dim=len(seeds_infection[0][0]))

vae_model = VAEModel(Encoder=encoder, Decoder=decoder).to(device)

optimizer_vae = Adam([{'params': vae_model.parameters()}], 
                 lr=1e-3)
vae_model.train()

VAEModel(
  (Encoder): Encoder(
    (FC_input): Linear(in_features=2708, out_features=1024, bias=True)
    (FC_input2): Linear(in_features=1024, out_features=1024, bias=True)
    (FC_output): Linear(in_features=1024, out_features=128, bias=True)
    (bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Decoder): Decoder(
    (FC_input): Linear(in_features=128, out_features=128, bias=True)
    (FC_hidden_1): Linear(in_features=128, out_features=1024, bias=True)
    (FC_hidden_2): Linear(in_features=1024, out_features=1024, bias=True)
    (FC_output): Linear(in_features=1024, out_features=2708, bias=True)
  )
)

In [9]:
for epoch in range(250):
    train_vae_loss = 0
    mean_train_accuracy = 0
    for batch_idx, seeds_label in enumerate(train_loader):        
        x = seeds_label[0].to(device)
        optimizer_vae.zero_grad()
        loss = 0
        for i, x_i in enumerate(x):
            x_hat = vae_model(x_i)

            reproduction_loss = F.binary_cross_entropy(x_hat, x_i, reduction='sum')   
            loss += reproduction_loss    
        train_vae_loss += loss.item()
        loss = loss/x.size(0)
        loss.backward()
        optimizer_vae.step()
        
    print("Epoch: {}".format(epoch+1), 
        "\tTrain_vae_loss: {:.4f}".format(train_vae_loss / train_batch_size),
        )

Epoch: 1 	Train_vae_loss: 21098.3594


Epoch: 2 	Train_vae_loss: 14003.8514
Epoch: 3 	Train_vae_loss: 13613.8539
Epoch: 4 	Train_vae_loss: 13556.5552
Epoch: 5 	Train_vae_loss: 13527.4182
Epoch: 6 	Train_vae_loss: 13517.6142
Epoch: 7 	Train_vae_loss: 13510.1918
Epoch: 8 	Train_vae_loss: 13508.2707
Epoch: 9 	Train_vae_loss: 13504.1915
Epoch: 10 	Train_vae_loss: 13495.8273
Epoch: 11 	Train_vae_loss: 13488.0256
Epoch: 12 	Train_vae_loss: 13488.4356
Epoch: 13 	Train_vae_loss: 13482.6939
Epoch: 14 	Train_vae_loss: 13474.5729
Epoch: 15 	Train_vae_loss: 13477.4797
Epoch: 16 	Train_vae_loss: 13474.8606
Epoch: 17 	Train_vae_loss: 13471.4833
Epoch: 18 	Train_vae_loss: 13466.6879
Epoch: 19 	Train_vae_loss: 13468.1844
Epoch: 20 	Train_vae_loss: 13461.6968
Epoch: 21 	Train_vae_loss: 13462.1439
Epoch: 22 	Train_vae_loss: 13459.3940
Epoch: 23 	Train_vae_loss: 13459.5847
Epoch: 24 	Train_vae_loss: 13454.3450
Epoch: 25 	Train_vae_loss: 13457.1714
Epoch: 26 	Train_vae_loss: 13454.4913
Epoch: 27 	Train_vae_loss: 13454.6875
Epoch: 28 	Train_vae

In [10]:
import torch
import torch.nn as nn

from torch_geometric.nn import GATConv
from torch.optim import Adam, SGD
import torch.nn.functional as F

class GAT(nn.Module):
    def __init__(self, seeds_dim, inflect_dim, hidden_channels, out_channels, num_heads):
        super(GAT, self).__init__()
        self.linear1 = nn.Linear(seeds_dim + inflect_dim, seeds_dim + inflect_dim)
        self.bn1 = nn.BatchNorm1d(seeds_dim + inflect_dim)
        self.conv1 = GATConv(seeds_dim + inflect_dim, hidden_channels, heads=num_heads)
        self.bn2 = nn.BatchNorm1d(hidden_channels * num_heads)
        self.conv2 = GATConv(hidden_channels * num_heads, hidden_channels * num_heads, heads=1)
        self.bn3 = nn.BatchNorm1d(seeds_dim)
        self.linear2 = nn.Linear(hidden_channels * num_heads + seeds_dim, out_channels)

    def forward(self, seeds_i, inflect_i, edge_index):
        x =  torch.cat((seeds_i, inflect_i), dim=-1)
        x = self.linear1(x)
        x = self.bn1(x)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.bn2(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        seeds_i = self.bn3(seeds_i)
        x =  torch.cat((x, seeds_i), dim=-1)
        x = self.linear2(x)
        return F.relu(x)

In [11]:
inflect_dim = latent_dim
seeds_dim = latent_dim

forward_model = GAT(seeds_dim,inflect_dim, 512, 1, 4)

optimizer = Adam([{'params': forward_model.parameters()}], 
                 lr=0.0001)

adj = adj.to(device)
forward_model = forward_model.to(device)
forward_model.train()

GAT(
  (linear1): Linear(in_features=256, out_features=256, bias=True)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1): GATConv(256, 512, heads=4)
  (bn2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): GATConv(2048, 2048, heads=1)
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear2): Linear(in_features=2176, out_features=1, bias=True)
)

In [12]:
for param in vae_model.parameters():
    param.requires_grad = False 
encoder = vae_model.Encoder

In [13]:
inflected = torch.tensor(individual_infection).T.to(device)
encode_inflected = torch.zeros(inflected.shape[0], latent_dim)
for i in range(inflected.shape[0]):
    encode_inflected_i  = encoder(inflected[i])
    encode_inflected[i] = encode_inflected_i
encode_inflected = encode_inflected.detach().to(device)
encode_inflected

tensor([[ 2.4010, -1.7477, -2.3545,  ..., -2.1918, -1.4849, -0.0159],
        [ 0.7068, -1.4042, -1.1411,  ..., -0.6180, -0.5700, -0.0185],
        [ 5.2496, -1.9671, -4.9019,  ..., -2.3424, -3.9156,  0.4494],
        ...,
        [ 0.9739, -0.7837, -0.3581,  ..., -0.1356, -0.0480, -0.3513],
        [ 3.4649, -1.8274, -2.7977,  ..., -3.4566, -1.0626, -0.1881],
        [ 4.0444, -1.4717, -3.2402,  ..., -1.7792, -2.4938,  0.0555]],
       device='cuda:7')

In [14]:
edge_index = edge_index.to(device)
top_num = 500


for epoch in range(2000):

    total_loss = 0
    

    total_train_accuracy = 0
    
    count_train = 0
    for batch_idx, seeds_label in enumerate(train_loader): 
        count_train += 1
        forward_loss = 0 
        seeds =  seeds_label[0].to(device)     
        labels = seeds_label[1].to(device)
        optimizer.zero_grad()
        
        loss = 0
        train_accuracy = 0
        for i, seeds_i in enumerate(seeds):
            

            
            seeds_i = encoder(seeds_i).detach()
            seeds_i = seeds_i.expand(seeds.shape[1], -1)
            
            y_i = labels[i]
            y_hat = forward_model(seeds_i, encode_inflected, edge_index)
            
            
            
            _, top_indices_true = torch.topk(y_i.clone(), top_num)
            label_2 = torch.zeros(y_i.shape).to(device)
            label_2[top_indices_true] = 1
            
            _, top_indices_predict = torch.topk(y_hat.clone().squeeze(-1), top_num)
            
            # 将张量数组转换为Python列表
            list1 = top_indices_true.tolist()
            list_pre = top_indices_predict.tolist()

            # 使用集合操作找到交集
            intersection = list(set(list1) & set(list_pre))
            accuracy_i = len(intersection) / top_num       
            train_accuracy += accuracy_i 

            forward_loss = 0.5*F.mse_loss(y_hat.squeeze(-1), y_i, reduction='sum') + F.mse_loss(y_hat.squeeze(-1), label_2, reduction='sum')    
            loss += forward_loss    
        
           
        train_accuracy /= seeds.size(0)
        total_train_accuracy += train_accuracy
        loss = loss/seeds.size(0)
        total_loss += loss.item() 
        loss.backward()
        optimizer.step()
        # for p in forward_model.parameters():
        #     p.data.clamp_(min=0)
        

    print("Epoch: {}".format(epoch+1), 
        "\tTotal: {:.4f}".format(total_loss / count_train),
        "\tMean_train_accuracy: {:.4f}".format(total_train_accuracy/ count_train),
        )  
    
    total_test_accuracy = 0
    
    count_test = 0

    for batch_idx, seeds_label in enumerate(test_loader): 
        count_test += 1 
        seeds =  seeds_label[0].to(device)     
        labels = seeds_label[1].to(device)
        test_accuracy = 0
        for i, seeds_i in enumerate(seeds):
            seeds_i = encoder(seeds_i).detach()
            seeds_i = seeds_i.expand(seeds.shape[1], -1)
            
            y_i = labels[i]
            
            y_hat = forward_model(seeds_i, encode_inflected, edge_index)
            
            _, top_indices_true = torch.topk(y_i, top_num)
            
            
            _, top_indices_predict = torch.topk(y_hat.squeeze(-1), top_num)

            
            # 将张量数组转换为Python列表
            list1 = top_indices_true.tolist()
            list_pre = top_indices_predict.tolist()
            

            # 使用集合操作找到交集
            intersection = list(set(list1) & set(list_pre))

            
            accuracy_i = len(intersection) / top_num       
            test_accuracy += accuracy_i 
        test_accuracy /= len(seeds)
        total_test_accuracy += test_accuracy
        

    print(
        "\tMean_test_accuracy: {:.4f}".format(total_test_accuracy / count_test),
        )  

    

Epoch: 1 	Total: 621.8201 	Mean_train_accuracy: 0.2085
	Mean_test_accuracy: 0.2071
Epoch: 2 	Total: 579.0218 	Mean_train_accuracy: 0.2054
	Mean_test_accuracy: 0.2085
Epoch: 3 	Total: 578.3410 	Mean_train_accuracy: 0.2064
	Mean_test_accuracy: 0.2079
Epoch: 4 	Total: 577.7807 	Mean_train_accuracy: 0.2067
	Mean_test_accuracy: 0.2090
Epoch: 5 	Total: 577.3220 	Mean_train_accuracy: 0.2064
	Mean_test_accuracy: 0.2082
Epoch: 6 	Total: 576.1103 	Mean_train_accuracy: 0.2048
	Mean_test_accuracy: 0.2094
Epoch: 7 	Total: 566.6875 	Mean_train_accuracy: 0.1883
	Mean_test_accuracy: 0.1823
Epoch: 8 	Total: 554.5610 	Mean_train_accuracy: 0.1903
	Mean_test_accuracy: 0.2018
Epoch: 9 	Total: 542.2965 	Mean_train_accuracy: 0.2185
	Mean_test_accuracy: 0.2435
Epoch: 10 	Total: 504.7531 	Mean_train_accuracy: 0.3112
	Mean_test_accuracy: 0.3827
Epoch: 11 	Total: 451.4218 	Mean_train_accuracy: 0.4239
	Mean_test_accuracy: 0.4520
Epoch: 12 	Total: 422.1953 	Mean_train_accuracy: 0.4821
	Mean_test_accuracy: 0.5031
E