In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, GCNConv
from scipy.sparse import coo_matrix
import numpy as np
from torch_geometric.utils import to_networkx
import random
from heapdict import heapdict
from node2vec import Node2Vec
import argparse
import torch.nn.init as init
from utils import *

In [None]:
# Module-level dict created empty here; it is not read or written anywhere in
# the visible cells. NOTE(review): presumably a cache keyed by edges that is
# filled by later cells - confirm against its users.
edge_dict = {}
# GCN model
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network (GCNConv -> ReLU -> GCNConv).

    Args:
        num_features: Size of each input node feature vector.
        graph_embedding_size: Output width of the second GCN layer.
        hidden_channels: Width of the hidden GCN layer. Defaults to 128, the
            value that was previously hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, num_features, graph_embedding_size, hidden_channels=128):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, graph_embedding_size)

    def forward(self, data):
        """Run both GCN layers over ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``, as a
        ``torch_geometric.data.Data`` object does. No activation is applied
        after the second layer, so the raw node-level output is returned.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        return self.conv2(x, edge_index)

# GAT model
class GAT(torch.nn.Module):
    """Two-layer graph attention network (GAT).

    Args:
        num_features: Size of each input node feature vector.
        num_heads: Number of attention heads in the first layer. Their outputs
            are concatenated, so the second layer receives
            ``hidden_channels * num_heads`` features per node.
        graph_embedding_size: Output width of the second, single-head layer.
        hidden_channels: Per-head output width of the first layer. Defaults to
            512, the value previously hard-coded.
        dropout: Probability used both for the feature dropout applied before
            each layer and for the attention dropout inside both ``GATConv``
            layers. Defaults to 0.2, the value previously hard-coded.
        apply_log_softmax: When True (the default, matching the original
            behaviour) ``forward`` returns ``log_softmax`` over dim 1. Pass
            False to get the raw second-layer output, e.g. when the result is
            consumed as an embedding rather than as class log-probabilities.
    """

    def __init__(self, num_features, num_heads=4, graph_embedding_size=256,
                 hidden_channels=512, dropout=0.2, apply_log_softmax=True):
        super(GAT, self).__init__()
        self.dropout = dropout
        self.apply_log_softmax = apply_log_softmax
        self.gat1 = GATConv(num_features, hidden_channels, heads=num_heads, dropout=dropout)
        # concat=False with a single head keeps the output at graph_embedding_size.
        self.gat2 = GATConv(hidden_channels * num_heads, graph_embedding_size,
                            heads=1, concat=False, dropout=dropout)

    def forward(self, data):
        """Run both attention layers over ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``, as a
        ``torch_geometric.data.Data`` object does. Feature dropout is active
        only in training mode.
        """
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat2(x, edge_index)
        if self.apply_log_softmax:
            return F.log_softmax(x, dim=1)
        return x

# MLP used to decide whether an edge should be modified
class MLP(nn.Module):
    """Fully connected scorer ``input_size -> 512 -> 256 -> 128 -> 1``.

    The three hidden layers use ReLU. The single output unit goes through a
    sigmoid, so every score lies between 0 and 1. After construction, all
    four weight matrices are re-initialised with Kaiming-normal (ReLU gain).
    Biases keep the default ``nn.Linear`` initialisation.
    """

    def __init__(self, input_size):
        super(MLP, self).__init__()
        widths = (input_size, 512, 256, 128, 1)
        # All layers are created first and only then re-initialised. This
        # preserves the order of random draws and the fc1..fc4 attribute names.
        self.fc1, self.fc2, self.fc3, self.fc4 = (
            nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])
        )
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            init.kaiming_normal_(layer.weight, nonlinearity='relu')

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return torch.sigmoid(self.fc4(hidden))
    
class GCN_edge_modify(nn.Module):
    """Two GCN layers followed by a linear head and a sigmoid.

    Args:
        num_features: Size of each input node feature vector.
        hidden_channels: Width of both GCN layers and input width of the
            linear head.

    NOTE(review): the head is applied to the node-level GCN output, so this
    yields one score per node, not per edge - confirm that is intended given
    the class name.
    """

    def __init__(self, num_features, hidden_channels = 512):
        # Bug fix: this previously called ``super(GCN, self).__init__()``.
        # GCN_edge_modify does not inherit from GCN, so that call raised
        # ``TypeError`` and the class could never be instantiated.
        super(GCN_edge_modify, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # Final layer, which produces the output score.
        self.out = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index):
        """Return sigmoid scores computed from node features ``x`` and ``edge_index``."""
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return torch.sigmoid(self.out(x))
    
class MLPClassifier(nn.Module):
    """Final classifier that judges whether graph results have the same MVC.

    Architecture: ``input_size -> 512 -> 128 -> 1`` with ReLU after the first
    two layers and a sigmoid on the output. All three weight matrices are
    re-initialised with Kaiming-normal (ReLU gain) after construction. Biases
    keep the default ``nn.Linear`` initialisation.
    """

    def __init__(self, input_size):
        super(MLPClassifier, self).__init__()
        sizes = (input_size, 512, 128, 1)
        # Create fc1, fc2, fc3 in order, then re-initialise them in the same
        # order, so random draws and state_dict keys are unchanged.
        self.fc1, self.fc2, self.fc3 = (
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(sizes[:-1], sizes[1:])
        )
        for linear in (self.fc1, self.fc2, self.fc3):
            init.kaiming_normal_(linear.weight, nonlinearity='relu')

    def forward(self, x):
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        # Sigmoid keeps the output between 0 and 1.
        return torch.sigmoid(self.fc3(features))

In [None]:
class Node2Edge(nn.Module):
    def __init__(self, num_features, num_heads, graph_embedding_size):
        """Build the GAT encoder used by this module.

        Args:
            num_features: Base node-feature width; the GAT is built with
                ``num_features + 1`` input features.
            num_heads: Attention heads for the GAT's first layer.
            graph_embedding_size: Output width of the GAT.
        """
        # Bug fix: nn.Module.__init__ must run before any submodule is
        # assigned. Without it, the ``self.gat = ...`` line raises
        # "AttributeError: cannot assign module before Module.__init__() call".
        super(Node2Edge, self).__init__()
        # NOTE(review): the extra input feature presumably corresponds to one
        # column appended to each node's features by the caller - confirm.
        self.gat = GAT(num_features=num_features + 1, num_heads=num_heads, graph_embedding_size = graph_embedding_size)
        