In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import train_test_split_edges, negative_sampling, to_dense_adj

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.linalg import sqrtm
import os

# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)

In [2]:
# 1. 加载数据
def load_data():
    dataset = Planetoid(root='data/Cora', name='Cora')
    data = dataset[0]
    data = train_test_split_edges(data)
    
    return data, dataset.num_node_features, dataset.num_classes

In [3]:
# 2. 定义生成器网络 - 生成节点嵌入
class Generator(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.3):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim*2), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim*2, hidden_dim*4), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim*4, hidden_dim*2), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim*2, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim), 
            nn.Tanh()   # 将生成的嵌入归一化到[-1, 1]范围
        )
        
    def forward(self, z): 
        return self.model(z)

In [4]:
# 3. 定义判别器网络 - 判别节点对是否有边连接
class Discriminator(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, dropout=0.3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim*2, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim*2), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim*2, hidden_dim*4), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim*4, hidden_dim*2), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim*2, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1), 
            nn.Sigmoid()    #输出二分类概率
        )
        
    def forward(self, x1, x2): 
        # 拼接两个节点的嵌入
        return self.model(torch.cat([x1, x2], dim=1))

In [5]:
# 4. 计算FID指标（适配图嵌入）
def calculate_fid(real_features, generated_features):
    # 计算均值和协方差矩阵
    mu1, sigma1 = np.mean(real_features, 0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(generated_features, 0), np.cov(generated_features, rowvar=False)
    
    # 计算均值差的平方和
    diff = np.sum((mu1 - mu2)** 2)
    
    # 计算协方差矩阵的平方根
    covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    
    # 数值稳定性处理
    if not np.isfinite(covmean).all():
        covmean = sqrtm((sigma1 + 1e-6*np.eye(sigma1.shape[0])).dot(sigma2 + 1e-6*np.eye(sigma2.shape[0])))
        
    # 如果存在虚数部分，取实部    
    if np.iscomplexobj(covmean): 
        covmean = covmean.real
    
    return diff + np.trace(sigma1 + sigma2 - 2 * covmean)

In [6]:
# 5. 判别器训练函数
def train_discriminator(generator, discriminator, optimizer_D, criterion, 
                        real_features, z, pos_x1, pos_x2, neg_x1, neg_x2, device):
    
    optimizer_D.zero_grad()
    fake_features = generator(z)
    batch_size = z.size(0)
    idx1, idx2 = torch.randperm(batch_size), torch.randperm(batch_size)
    fake_x1, fake_x2 = fake_features[idx1], fake_features[idx2]
    real_pred = discriminator(pos_x1, pos_x2)
    fake_pred = discriminator(fake_x1, fake_x2)
    real_loss = criterion(real_pred, torch.ones_like(real_pred))
    fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optimizer_D.step()
    
    return d_loss.item()

In [7]:
# 6. 生成器训练函数
def train_generator(generator, discriminator, optimizer_G, criterion, z, device):
    
    optimizer_G.zero_grad()
    fake_features = generator(z)
    batch_size = z.size(0)
    idx1, idx2 = torch.randperm(batch_size), torch.randperm(batch_size)
    fake_x1, fake_x2 = fake_features[idx1], fake_features[idx2]
    fake_pred = discriminator(fake_x1, fake_x2)
    g_loss = criterion(fake_pred, torch.ones_like(fake_pred))
    g_loss.backward()
    optimizer_G.step()
    
    return g_loss.item()

In [8]:
# 7. 绘图函数
def plot_training_metrics(g_losses, d_losses, auc_scores, ap_scores, epochs):
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(g_losses, label='Generator Loss')
    plt.plot(d_losses, label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Losses')
    
    plt.subplot(1, 2, 2)
    plt.plot(range(0, epochs, 10), auc_scores, label='AUC')
    plt.plot(range(0, epochs, 10), ap_scores, label='AP')
    
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.title('Validation Performance')
    plt.tight_layout()
    
    plt.savefig('training_metrics.png')
    
    plt.close()

In [9]:
# 8. 训练函数（分离绘图）
def train_gan(data, num_features, device, epochs=200, batch_size=64, latent_dim=128):
    
    generator = Generator(latent_dim, 128, num_features).to(device)
    discriminator = Discriminator(num_features, 128).to(device)
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    criterion = nn.BCELoss()
    
    real_features = data.x.to(device)
    g_losses, d_losses, auc_scores, ap_scores = [], [], [], []
    val_pos = data.val_pos_edge_index.to(device)
    val_neg = data.val_neg_edge_index.to(device)
    
    for epoch in range(epochs):
        generator.train(), discriminator.train()
        pos_edge = data.train_pos_edge_index.to(device)
        neg_edge = data.train_neg_edge_index.to(device)
        num_batches = pos_edge.size(1) // batch_size
        epoch_g, epoch_d = 0, 0
        
        for batch in range(num_batches):
            pos_idx = torch.randperm(pos_edge.size(1))[:batch_size]
            neg_idx = torch.randperm(neg_edge.size(1))[:batch_size]
            pos_pairs, neg_pairs = pos_edge[:, pos_idx], neg_edge[:, neg_idx]
            pos_x1, pos_x2 = real_features[pos_pairs[0]], real_features[pos_pairs[1]]
            neg_x1, neg_x2 = real_features[neg_pairs[0]], real_features[neg_pairs[1]]
            z = torch.randn(batch_size, latent_dim).to(device)
            
            d_loss = train_discriminator(generator, discriminator, optimizer_D, criterion,
                                        real_features, z, pos_x1, pos_x2, neg_x1, neg_x2, device)
            g_loss = train_generator(generator, discriminator, optimizer_G, criterion, z, device)
            epoch_d += d_loss
            epoch_g += g_loss
        
        epoch_d, epoch_g = epoch_d/num_batches, epoch_g/num_batches
        g_losses.append(epoch_g)
        d_losses.append(epoch_d)
        
        if (epoch+1) % 10 == 0:
            generator.eval()
            with torch.no_grad():
                z = torch.randn(real_features.size(0), latent_dim).to(device)
                gen_features = generator(z)
                pos_scores = discriminator(gen_features[val_pos[0]], gen_features[val_pos[1]]).squeeze().cpu()
                neg_scores = discriminator(gen_features[val_neg[0]], gen_features[val_neg[1]]).squeeze().cpu()
                y_true = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
                y_scores = np.concatenate([pos_scores.numpy(), neg_scores.numpy()])
                auc, ap = roc_auc_score(y_true, y_scores), average_precision_score(y_true, y_scores)
                real_np = real_features.cpu().numpy()
                gen_np = gen_features.cpu().numpy()
                fid = calculate_fid(real_np, gen_np)
                print(f'Epoch {epoch+1} | G:{epoch_g:.4f} D:{epoch_d:.4f} | AUC:{auc:.4f} AP:{ap:.4f} FID:{fid:.4f}')
                auc_scores.append(auc)
                ap_scores.append(ap)
    
    return generator, discriminator, g_losses, d_losses, auc_scores, ap_scores

In [10]:
# 9. 生成与评估
def generate_and_evaluate(generator, data, num_features, device, num_samples=5000, latent_dim=128):
    generator.eval()
    gen_features = []
    with torch.no_grad():
        for i in range(0, num_samples, 100):
            batch = min(100, num_samples - i)
            z = torch.randn(batch, latent_dim).to(device)
            gen_features.append(generator(z).cpu().numpy())
    gen_features = np.concatenate(gen_features, 0)[:num_samples]
    np.save('generated_features.npy', gen_features)
    
    real_features = data.x.cpu().numpy()
    fid = calculate_fid(real_features, gen_features)
    
    from sklearn.decomposition import PCA
    pca = PCA(2)
    real_pca = pca.fit_transform(real_features)
    gen_pca = pca.transform(gen_features[:len(real_features)])
    plt.figure(figsize=(10, 8))
    plt.scatter(real_pca[:, 0], real_pca[:, 1], alpha=0.5, label='Real')
    plt.scatter(gen_pca[:, 0], gen_pca[:, 1], alpha=0.5, label='Generated')
    plt.title('PCA Visualization')
    plt.legend()
    plt.savefig('feature_visualization.png')
    
    plt.close()
    return fid

In [11]:
def construct_negative_edges(data, num_neg_samples=None, is_undirected=True):
    """
    为Cora数据集手动构造负边
    
    参数:
        data: Cora数据集的Data对象
        num_neg_samples: 要构造的负边数量，默认与正边数量相同
        is_undirected: 是否为无向图
    """
    # 获取节点数量
    num_nodes = data.num_nodes
    
    # 获取训练集正边索引
    pos_edge_index = data.train_pos_edge_index
    
    # 如果未指定负边数量，默认与正边数量相同
    if num_neg_samples is None:
        num_neg_samples = pos_edge_index.size(1)
    
    # 构造负边（使用PyG的negative_sampling函数）
    # 注意：需传入排除的边索引（即正边），避免采样到已存在的边
    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index,  # 基于现有边结构采样
        num_nodes=num_nodes,        # 节点总数
        num_neg_samples=num_neg_samples,  # 负边数量
        method='sparse'             # 稀疏采样方法，适用于大图
    )
    
    # 如果是无向图，确保负边对称性（可选）
    if is_undirected:
        # 将负边转换为对称形式（i->j 和 j->i 都作为负边）
        neg_edge_index = torch.cat([neg_edge_index, neg_edge_index.flip(0)], dim=1)
    
    # 将构造的负边添加到data对象
    data.train_neg_edge_index = neg_edge_index
    
    return data

In [12]:
# 主函数

def main():
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    data, num_features, _ = load_data()
    data = construct_negative_edges(data)
    data = data.to(device)
    
    gen, dis, g_losses, d_losses, auc_scores, ap_scores = train_gan(data, num_features, device, epochs=200, batch_size=64, latent_dim=128)
    
    fid = generate_and_evaluate(gen, data, num_features, device)
    print(f"Final FID: {fid:.4f}")
    
    plot_training_metrics(g_losses, d_losses, auc_scores, ap_scores, 200)
    os.makedirs('models', exist_ok=True)
    torch.save(gen.state_dict(), 'models/generator.pth')
    torch.save(dis.state_dict(), 'models/discriminator.pth')

In [13]:
if __name__ == '__main__':
    main()

Using device: cuda




Epoch 10 | G:5.0283 D:0.0720 | AUC:0.4511 AP:0.4730 FID:19.4291
Epoch 20 | G:5.8503 D:0.0622 | AUC:0.5109 AP:0.5290 FID:19.4626
Epoch 30 | G:6.7200 D:0.0354 | AUC:0.4662 AP:0.4840 FID:20.2776
Epoch 40 | G:6.7650 D:0.0402 | AUC:0.4835 AP:0.4873 FID:20.3234
Epoch 50 | G:7.3365 D:0.0379 | AUC:0.5770 AP:0.5662 FID:22.7833
Epoch 60 | G:7.5749 D:0.0358 | AUC:0.4925 AP:0.5021 FID:21.5222
Epoch 70 | G:7.5576 D:0.0426 | AUC:0.4801 AP:0.4877 FID:22.4408
Epoch 80 | G:6.9788 D:0.0564 | AUC:0.4963 AP:0.5064 FID:24.4092
Epoch 90 | G:6.5951 D:0.0603 | AUC:0.4709 AP:0.4837 FID:25.6409
Epoch 100 | G:6.5935 D:0.0701 | AUC:0.4562 AP:0.4824 FID:27.6551
Epoch 110 | G:6.4599 D:0.0794 | AUC:0.5008 AP:0.4986 FID:28.7429
Epoch 120 | G:6.3825 D:0.0859 | AUC:0.4815 AP:0.4948 FID:32.9534
Epoch 130 | G:5.8165 D:0.0965 | AUC:0.5066 AP:0.5228 FID:31.9429
Epoch 140 | G:6.3037 D:0.0904 | AUC:0.5413 AP:0.5355 FID:31.2433
Epoch 150 | G:5.6318 D:0.0969 | AUC:0.4759 AP:0.4894 FID:31.4751
Epoch 160 | G:5.6821 D:0.0984 | AU