In [8]:
import pickle
import argparse

# All tensors and the model are placed on this device.
device = 'cpu'

# Experiment configuration, expressed as an argparse parser so the same options
# can be reused from a command-line script.
parser = argparse.ArgumentParser(description="GenIM")
datasets = ['jazz', 'cora_ml', 'power_grid', 'netscience', 'random5']
parser.add_argument("-d", "--dataset", default="cora_ml", type=str,
                    help="one of: {}".format(", ".join(sorted(datasets))))
diffusion = ['IC', 'LT', 'SIS']
parser.add_argument("-dm", "--diffusion_model", default="LT", type=str,
                    help="one of: {}".format(", ".join(sorted(diffusion))))
seed_rate = [1, 5, 10, 20]
# Stringify each rate individually; joining str(list) would join the characters
# of the list's repr ("[, 1, ,, ...") instead of the values.
parser.add_argument("-sp", "--seed_rate", default=1, type=int,
                    help="one of: {}".format(", ".join(str(rate) for rate in sorted(seed_rate))))
mode = ['Normal', 'Budget Constraint']
# NOTE(review): the default "normal" is lower-case while the listed modes are
# capitalised; left unchanged because other code may compare against "normal".
parser.add_argument("-m", "--mode", default="normal", type=str,
                    help="one of: {}".format(", ".join(sorted(mode))))
# Passing an explicit empty list stops argparse from reading sys.argv (which in a
# notebook holds the kernel's own arguments), so every option takes its default.
args = parser.parse_args(args=[])


# Build the path of the pickled graph file for the selected dataset and
# diffusion model, then unpickle it into `graph`.
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
graph_path = ''.join(['data/', args.dataset, '_mean_', args.diffusion_model, str(50), '.SG'])
with open(graph_path, 'rb') as f:
    graph = pickle.load(f)

  graph = pickle.load(f)


### 开始训练

In [9]:
import pickle

# Load the graph pickle for the chosen diffusion model from the working
# directory. This replaces whatever `graph` held before this cell ran.
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
graph_file = f'cora_{args.diffusion_model}_50.SG'
with open(graph_file, 'rb') as f:
    graph = pickle.load(f)

In [10]:
import scipy.sparse as sp
from torch.utils.data import DataLoader
import numpy as np
import torch

def normalize_adj(mx):
    """Symmetrically normalize a sparse matrix.

    Computes ``D^{-1/2} @ mx.T @ D^{-1/2}``, where ``D`` is the diagonal matrix
    holding the row sums of ``mx``. (This is not a row normalization: both
    sides are scaled, and the matrix is transposed in between.)

    Rows whose sum is zero would yield an infinite scaling factor; those
    factors are replaced by 0 so the corresponding rows/columns become zero.

    Args:
        mx: square scipy sparse matrix.

    Returns:
        The normalized scipy sparse matrix, same shape as ``mx``.
    """
    rowsum = np.array(mx.sum(1)).flatten()
    # 0 ** -0.5 evaluates to inf; that is expected and handled on the next line,
    # so suppress the divide-by-zero RuntimeWarning numpy would otherwise emit.
    with np.errstate(divide='ignore'):
        r_inv_sqrt = np.power(rowsum, -0.5)
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
    r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
    return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)

# Hyper-parameters. `batch_size` is also used as the size of the held-out test split.
batch_size = 16
hidden_dim = 1024
latent_dim = 512

adj = graph['adj']
inverse_pairs = graph['inverse_pairs']

# Symmetrise the adjacency matrix: wherever the transposed entry is larger,
# take it instead of the original entry (element-wise maximum of adj and adj.T).
transpose_larger = adj.T > adj
adj = adj + adj.T.multiply(transpose_larger) - adj.multiply(transpose_larger)
# Add self-loops, normalise, and convert to a sparse torch tensor.
adj = normalize_adj(adj + sp.eye(adj.shape[0]))
adj = torch.Tensor(adj.toarray()).to_sparse()

# Random split: the last `batch_size` samples (after shuffling) form the test set.
n_test = batch_size
n_train = len(inverse_pairs) - n_test
train_set, test_set = torch.utils.data.random_split(inverse_pairs, [n_train, n_test])
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, drop_last=False)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

In [11]:
from main.model.gat import GAT, SpGAT
from torch.optim import Adam, SGD


# Sparse graph-attention network used as the forward model.
# NOTE(review): nfeat=1 / nclass=1 presumably mean one scalar input and one
# scalar output per node — confirm against main.model.gat.
gat_config = dict(nfeat=1, nhid=64, nclass=1, dropout=0.1, nheads=4, alpha=0.2)
forward_model = SpGAT(**gat_config)

# A single parameter group containing every model parameter.
optimizer = Adam(forward_model.parameters(), lr=1e-3)

# Move the adjacency tensor and the model to the target device, then put the
# model in training mode (the call's return value is the cell's displayed output).
adj = adj.to(device)
forward_model = forward_model.to(device)
forward_model.train()

SpGAT(
  (attention_0): SpGraphAttentionLayer (1 -> 64)
  (attention_1): SpGraphAttentionLayer (1 -> 64)
  (attention_2): SpGraphAttentionLayer (1 -> 64)
  (attention_3): SpGraphAttentionLayer (1 -> 64)
  (attention1_0): SpGraphAttentionLayer (256 -> 64)
  (attention1_1): SpGraphAttentionLayer (256 -> 64)
  (attention1_2): SpGraphAttentionLayer (256 -> 64)
  (attention1_3): SpGraphAttentionLayer (256 -> 64)
  (out_att): SpGraphAttentionLayer (256 -> 1)
)

In [12]:
# import torch.nn.functional as F
# for epoch in range(600):
#     # 训练
#     loss = 0
#     for t in range(80):
#         y_hat = forward_model(x[t].unsqueeze(-1), adj)
#         forward_loss = F.binary_cross_entropy(y_hat.squeeze(-1), y[t], reduction='sum')    
#         loss += forward_loss
#     loss /= 80
#     loss.backward()
#     optimizer.step()
#     for p in forward_model.parameters():
#         p.data.clamp_(min=0)
    
#     # 测试
#     with torch.no_grad():
#         test_loss = 0
#         correct = 0
#         correct_1 = 0
#         for t in range(80,100):
#             y_hat = forward_model(x[t].unsqueeze(-1), adj)
#             # forward_loss = F.binary_cross_entropy(y_hat.squeeze(-1), y[t], reduction='sum')    
#             forward_loss = F.mse_loss(y_hat.squeeze(-1), y[t], reduction='sum')
#             test_loss += forward_loss
            
#             threshold = 0.6
#             y_pre = y_hat.squeeze(-1)
#             filtered_y_hat = (y_pre > threshold).float()
#             correct += ((filtered_y_hat == y[t]).sum()/len(y[t]))
#             count_both_ones = torch.sum((filtered_y_hat == 1) & (y[t] == 1))
#             correct_1 += count_both_ones/y[t].sum()
#         correct /= 20
#         correct_1 /= 20    
#         test_loss /= 20
#     print("Epoch: {}".format(epoch+1), 
#         "\tTotal: {:.4f}".format(loss),
#         "\tTest_loss: {:.4f}".format(test_loss),
#         "\tTest_accuracy: {:.4f}".format(correct),
#         "\tTest_accuracy_1: {:.4f}".format(correct_1),
#         # "\tTime: {:.4f}".format(end - begin)
#         )
        

In [15]:
import torch.nn.functional as F
def loss_all(y, y_hat, threshold=0.3):
    """Compute the forward loss and two thresholded accuracy metrics.

    Args:
        y: target tensor.
        y_hat: prediction tensor with the same shape as ``y``.
        threshold: predictions strictly greater than this value are counted
            as 1, all others as 0. Defaults to 0.3.

    Returns:
        A tuple ``(forward_loss, correct, correct_1)`` of 0-dim tensors:
        ``forward_loss`` is the summed squared error between ``y_hat`` and ``y``;
        ``correct`` is the fraction of elements whose thresholded prediction
        equals the target; ``correct_1`` is the number of elements where both
        the thresholded prediction and the target equal 1, divided by
        ``y.sum()``. When ``y.sum()`` is zero, ``correct_1`` is 0 rather than
        NaN so that one such sample cannot poison accumulated metrics.
    """
    forward_loss = F.mse_loss(y_hat, y, reduction='sum')

    filtered_y_hat = (y_hat > threshold).float()
    # numel() rather than len() so the fraction is also right for multi-dim input.
    correct = (filtered_y_hat == y).sum() / y.numel()

    count_both_ones = torch.sum((filtered_y_hat == 1) & (y == 1))
    positives = y.sum()
    ratio = count_both_ones / positives
    # 0 / 0 is NaN; report 0 when the target contains no positive mass.
    correct_1 = torch.where(positives != 0, ratio, torch.zeros_like(ratio))

    return forward_loss, correct, correct_1

In [16]:
import time


NUM_EPOCHS = 600  # number of passes over the training split


def _run_batch(x, y):
    """Run the forward model on each sample of a batch, one sample at a time.

    Args:
        x: batch of model inputs; iterated along the first dimension.
        y: batch of targets aligned with ``x``.

    Returns:
        ``(loss_sum, correct_sum, correct_1_sum)`` — the summed loss as a tensor
        (it carries the autograd graph when gradients are enabled) and the two
        summed accuracy metrics from ``loss_all`` as Python floats.
    """
    loss_sum = 0
    correct_sum = 0.0
    correct_1_sum = 0.0
    for x_i, y_i in zip(x, y):
        # The model is fed a trailing feature dimension, which is dropped
        # again from its output before computing the loss.
        y_hat = forward_model(x_i.unsqueeze(-1), adj).squeeze(-1)
        forward_loss, correct_t, correct_1_t = loss_all(y_i, y_hat)
        loss_sum = loss_sum + forward_loss
        correct_sum += correct_t.item()
        correct_1_sum += correct_1_t.item()
    return loss_sum, correct_sum, correct_1_sum


for epoch in range(NUM_EPOCHS):
    begin = time.time()

    # Training mode (dropout active); the evaluation phase below switches to eval mode.
    forward_model.train()
    train_total_loss = 0.0
    train_total_correct = 0.0
    train_total_correct_1 = 0.0

    for data_pair in train_loader:
        x = data_pair[:, :, 0].float().to(device)
        y = data_pair[:, :, 1].float().to(device)
        optimizer.zero_grad()

        train_loss, train_correct, train_correct_1 = _run_batch(x, y)

        # The summed loss is what gets reported; the batch mean is what gets optimised.
        train_total_loss += train_loss.item()
        train_loss = train_loss / x.size(0)
        train_loss.backward()
        optimizer.step()

        # Training accuracy (summed over samples, averaged when printed).
        train_total_correct += train_correct
        train_total_correct_1 += train_correct_1

        # Clamp every parameter to be non-negative after each update.
        for p in forward_model.parameters():
            p.data.clamp_(min=0)

    # Evaluate on the test set with dropout disabled and without tracking gradients.
    forward_model.eval()
    test_total_loss = 0.0
    test_total_correct = 0.0
    test_total_correct_1 = 0.0

    with torch.no_grad():
        for data_pair in test_loader:
            x = data_pair[:, :, 0].float().to(device)
            y = data_pair[:, :, 1].float().to(device)

            test_loss, test_correct, test_correct_1 = _run_batch(x, y)

            test_total_loss += test_loss.item()
            test_total_correct += test_correct
            test_total_correct_1 += test_correct_1

    end = time.time()
    print("Epoch: {}".format(epoch+1), 
          "\tTrain_loss: {:.4f}".format(train_total_loss / len(train_set)),
          "\tTrain_accuracy: {:.4f}".format(train_total_correct / len(train_set)),
          "\tTrain_accuracy_1: {:.4f}".format(train_total_correct_1 / len(train_set)),
          "\tTest_loss: {:.4f}".format(test_total_loss / len(test_set)),
          "\tTest_accuracy: {:.4f}".format(test_total_correct / len(test_set)),
          "\tTest_accuracy_1: {:.4f}".format(test_total_correct_1 / len(test_set)),
          "\tTime: {:.4f}".format(end - begin)
         )

Epoch: 1 	Train_loss: 839.5335 	Train_accuracy: 0.6967 	Train_accuracy_1: 0.5163 	Test_loss: 745.6663 	Test_accuracy: 0.7448 	Test_accuracy_1: 0.5749 	Time: 5.9124
Epoch: 2 	Train_loss: 796.1602 	Train_accuracy: 0.7231 	Train_accuracy_1: 0.5694 	Test_loss: 717.0917 	Test_accuracy: 0.7566 	Test_accuracy_1: 0.6118 	Time: 5.7136
Epoch: 3 	Train_loss: 763.3956 	Train_accuracy: 0.7404 	Train_accuracy_1: 0.6102 	Test_loss: 679.6266 	Test_accuracy: 0.7770 	Test_accuracy_1: 0.6588 	Time: 5.8065
Epoch: 4 	Train_loss: 737.5449 	Train_accuracy: 0.7556 	Train_accuracy_1: 0.6427 	Test_loss: 662.6205 	Test_accuracy: 0.7830 	Test_accuracy_1: 0.6822 	Time: 5.8709
Epoch: 5 	Train_loss: 710.9743 	Train_accuracy: 0.7659 	Train_accuracy_1: 0.6731 	Test_loss: 638.2781 	Test_accuracy: 0.7938 	Test_accuracy_1: 0.7116 	Time: 5.7111
Epoch: 6 	Train_loss: 693.1459 	Train_accuracy: 0.7734 	Train_accuracy_1: 0.6943 	Test_loss: 631.4355 	Test_accuracy: 0.7952 	Test_accuracy_1: 0.7247 	Time: 5.6756
Epoch: 7 	Train_