In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.model_selection import train_test_split
import random


  from .autonotebook import tqdm as notebook_tqdm


In [19]:
import networkx as nx
from torch_geometric.utils import from_networkx, to_networkx
import copy

In [2]:

# -------------------------------
# CONFIG
# -------------------------------
TOTAL_NODES = 26  # Size of node space (from G)
HIDDEN_DIM = 64
EPOCHS = 50
LEARNING_RATE = 0.01


class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

class EdgeDecoder(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels * 2, 1)

    def forward(self, z, edge_index):
        src, dst = edge_index
        edge_feats = torch.cat([z[src], z[dst]], dim=1)
        return self.linear(edge_feats).squeeze()

class GraphCompletionModel(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.encoder = GCNEncoder(in_channels, hidden_channels)
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x, edge_index, candidate_edges):
        z = self.encoder(x, edge_index)
        scores = self.decoder(z, candidate_edges)
        return scores


In [3]:


def sample_non_edges(num_nodes, existing_edges, num_samples):
    existing_set = set(existing_edges)
    all_possible = [(i, j) for i in range(num_nodes) for j in range(num_nodes) if i != j]
    candidates = list(set(all_possible) - existing_set)
    return random.sample(candidates, min(num_samples, len(candidates)))

def compute_accuracy(scores, labels, threshold=0.5):
    preds = (torch.sigmoid(scores) > threshold).float()
    correct = (preds == labels).sum().item()
    return correct / len(labels)


In [4]:
def prepare_supervised_data(G_prime_list, G_double_prime_LOL, total_nodes):
    data = []
    for i in range(len(G_prime_list)):
        G_prime = G_prime_list[i]
        G_double_primes = G_double_prime_LOL[i]

        true_edges = list(map(tuple, G_prime.edge_index.t().tolist()))

        for G_double_prime in G_double_primes:
            observed_edges = list(map(tuple, G_double_prime.edge_index.t().tolist()))
            positive_edges = [e for e in true_edges if e not in observed_edges]
            negative_edges = sample_non_edges(total_nodes, true_edges, len(positive_edges))

            data.append((G_double_prime, positive_edges, negative_edges))
    return data



In [5]:
def train_model(model, train_data, test_data, total_nodes, epochs=50, lr=0.01, device='cpu'):
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for G_double_prime, pos_edges, neg_edges in train_data:
            x = torch.eye(total_nodes).to(device)
            edge_index = G_double_prime.edge_index.to(device)
            candidate_edges = torch.tensor(pos_edges + neg_edges, dtype=torch.long).t().contiguous().to(device)
            labels = torch.tensor([1]*len(pos_edges) + [0]*len(neg_edges), dtype=torch.float).to(device)

            optimizer.zero_grad()
            scores = model(x, edge_index, candidate_edges)
            loss = F.binary_cross_entropy_with_logits(scores, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluation
        model.eval()
        with torch.no_grad():
            total_val_loss = 0
            total_val_acc = 0
            total_samples = 0

            for G_double_prime, pos_edges, neg_edges in test_data:
                x = torch.eye(total_nodes).to(device)
                edge_index = G_double_prime.edge_index.to(device)
                candidate_edges = torch.tensor(pos_edges + neg_edges, dtype=torch.long).t().contiguous().to(device)
                labels = torch.tensor([1]*len(pos_edges) + [0]*len(neg_edges), dtype=torch.float).to(device)

                scores = model(x, edge_index, candidate_edges)
                val_loss = F.binary_cross_entropy_with_logits(scores, labels)
                total_val_loss += val_loss.item()
                total_val_acc += compute_accuracy(scores, labels) * len(labels)
                total_samples += len(labels)

            print(f"[Epoch {epoch+1}] Train Loss: {total_loss:.4f} | Val Loss: {total_val_loss:.4f} | Val Acc: {total_val_acc/total_samples:.4f}")



In [6]:

def run_pipeline(G, G_prime_list, G_double_prime_LOL):
    data = prepare_supervised_data(G_prime_list, G_double_prime_LOL, TOTAL_NODES)
    train_set, test_set = train_test_split(data, test_size=0.2, random_state=42)

    model = GraphCompletionModel(in_channels=TOTAL_NODES, hidden_channels=HIDDEN_DIM)
    train_model(model, train_set, test_set, TOTAL_NODES, epochs=EPOCHS, lr=LEARNING_RATE)

    return model


In [7]:
import pickle as pkl

with open("Main_graph_withNodeFeats.pkl", "rb") as f:
    G = pkl.load(f)

In [13]:
def generate_connected_subgraphs(G, k, n, seed=None):
    if seed is not None:
        random.seed(seed)

    if G.number_of_nodes() <= k:
        raise ValueError("Cannot remove more nodes than exist in the graph.")

    subgraphs = []
    attempts = 0
    max_attempts = 100 * n  # safety to avoid infinite loops

    while len(subgraphs) < n and attempts < max_attempts:
        attempts += 1
        nodes_to_remove = random.sample(list(G.nodes()), k)
        G_sub = G.copy()
        G_sub.remove_nodes_from(nodes_to_remove)

        if nx.is_weakly_connected(G_sub):
            subgraphs.append(G_sub)

    return subgraphs

In [17]:
graph_ls = []
subgraph_ls = []
for k in range(5):
    subgraphs= generate_connected_subgraphs(G, k, n=10, seed=123)
    subgraph_ls.append(subgraphs)

subgraph_ls = [g for graphs in subgraph_ls for g in graphs]
graph_data_obj_ls = []
for nx_graph in subgraph_ls:
    graph_data_obj = from_networkx(nx_graph)
    graph_data_obj_ls.append(graph_data_obj)


In [20]:
subgraph_data_obj_ls = []

for data in graph_data_obj_ls:
    num_edges = data.edge_index.size(1)
    masked_graphs_per_data = []  # inner list for each data graph

    for edges_to_remove in range(1, 6):  # from 1 to 5
        for _ in range(15):  # generate 15 graphs per mask level
            if num_edges <= edges_to_remove:
                continue  # can't remove more edges than exist

            data_copy = copy.deepcopy(data)
            edge_indices = list(range(num_edges))
            to_remove = random.sample(edge_indices, edges_to_remove)

            mask = torch.ones(num_edges, dtype=torch.bool)
            mask[to_remove] = False

            data_copy.edge_index = data.edge_index[:, mask]

            if hasattr(data, 'edge_attr') and data.edge_attr is not None:
                data_copy.edge_attr = data.edge_attr[mask]

            masked_graphs_per_data.append(data_copy)

    subgraph_data_obj_ls.append(masked_graphs_per_data)


In [21]:
run_pipeline(G=G, G_prime_list=graph_data_obj_ls, G_double_prime_LOL=subgraph_data_obj_ls)

[Epoch 1] Train Loss: 1624.5221 | Val Loss: 379.9838 | Val Acc: 0.7449
[Epoch 2] Train Loss: 1478.9567 | Val Loss: 357.8744 | Val Acc: 0.7705
[Epoch 3] Train Loss: 1430.2219 | Val Loss: 345.0938 | Val Acc: 0.7827
[Epoch 4] Train Loss: 1399.7146 | Val Loss: 339.1912 | Val Acc: 0.7984
[Epoch 5] Train Loss: 1380.0853 | Val Loss: 332.1214 | Val Acc: 0.8034
[Epoch 6] Train Loss: 1364.8964 | Val Loss: 329.9406 | Val Acc: 0.8087
[Epoch 7] Train Loss: 1354.4553 | Val Loss: 324.6808 | Val Acc: 0.8110
[Epoch 8] Train Loss: 1343.8439 | Val Loss: 325.0956 | Val Acc: 0.8076
[Epoch 9] Train Loss: 1335.9669 | Val Loss: 323.9693 | Val Acc: 0.8138
[Epoch 10] Train Loss: 1331.9361 | Val Loss: 323.1970 | Val Acc: 0.8105
[Epoch 11] Train Loss: 1326.0163 | Val Loss: 323.3902 | Val Acc: 0.8039
[Epoch 12] Train Loss: 1318.2781 | Val Loss: 327.3375 | Val Acc: 0.8094
[Epoch 13] Train Loss: 1315.2394 | Val Loss: 326.5176 | Val Acc: 0.8079
[Epoch 14] Train Loss: 1310.7018 | Val Loss: 328.9232 | Val Acc: 0.8110
[

GraphCompletionModel(
  (encoder): GCNEncoder(
    (conv1): GCNConv(26, 64)
    (conv2): GCNConv(64, 64)
  )
  (decoder): EdgeDecoder(
    (linear): Linear(in_features=128, out_features=1, bias=True)
  )
)