In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.model_selection import train_test_split
import random


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# -------------------------------
# CONFIG
# -------------------------------
# Tunable constants consumed by run_pipeline() further down in this notebook.
TOTAL_NODES = 26  # Size of node space (from G); also passed as the model's in_channels and as train_model's total_nodes
HIDDEN_DIM = 64  # Passed as hidden_channels when the GraphCompletionModel is built
EPOCHS = 50  # Number of training epochs handed to train_model
LEARNING_RATE = 0.01  # Learning rate handed to train_model (used for its Adam optimizer)


class GCNEncoder(nn.Module):
    """Two stacked GCNConv layers that map node features to node embeddings.

    A ReLU is applied after the first convolution; the output of the second
    convolution is returned as-is (no activation).

    Args:
        in_channels: Feature size expected by the first GCNConv layer.
        hidden_channels: Output size of both GCNConv layers.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the RNG the same way on every run.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Encode node features `x` using the connectivity in `edge_index`."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        embeddings = self.conv2(hidden, edge_index)
        return embeddings

class EdgeDecoder(nn.Module):
    """Scores candidate edges from the embeddings of their two endpoints.

    Each candidate (src, dst) is represented by the concatenation of the two
    endpoint embeddings and passed through a single linear layer that emits
    one raw score (logit) per edge.

    Args:
        in_channels: Size of one node embedding; the linear layer therefore
            takes `2 * in_channels` inputs.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels * 2, 1)

    def forward(self, z, edge_index):
        """Return one logit per candidate edge.

        Args:
            z: Node embedding matrix, indexed by node id along the first axis.
            edge_index: Two-row index tensor; row 0 holds source node ids and
                row 1 holds destination node ids.

        Returns:
            A 1-D tensor with one score per column of `edge_index`.
        """
        src, dst = edge_index
        edge_feats = torch.cat([z[src], z[dst]], dim=1)
        # Drop only the trailing size-1 feature axis. A bare `.squeeze()`
        # would also drop the edge axis when exactly one candidate is scored,
        # producing a 0-dim tensor that no longer lines up with a (1,) label
        # tensor in the loss.
        return self.linear(edge_feats).squeeze(-1)

class GraphCompletionModel(nn.Module):
    """Encoder/decoder pair that scores candidate edges of a graph.

    The encoder (GCNEncoder) produces node embeddings from the observed graph;
    the decoder (EdgeDecoder) turns pairs of embeddings into edge scores.

    Args:
        in_channels: Node feature size given to the encoder.
        hidden_channels: Embedding size shared by encoder and decoder.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.encoder = GCNEncoder(in_channels, hidden_channels)
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x, edge_index, candidate_edges):
        """Embed nodes with `edge_index`, then score every edge in `candidate_edges`."""
        node_embeddings = self.encoder(x, edge_index)
        return self.decoder(node_embeddings, candidate_edges)


In [3]:


def sample_non_edges(num_nodes, existing_edges, num_samples):
    """Randomly pick directed node pairs that are not in `existing_edges`.

    Args:
        num_nodes: Node ids are taken from ``range(num_nodes)``.
        existing_edges: Iterable of hashable (src, dst) pairs to exclude.
        num_samples: Upper bound on how many pairs to return.

    Returns:
        A list of distinct (src, dst) tuples with ``src != dst``, none of which
        appear in `existing_edges`. Its length is `num_samples`, or the number
        of available pairs when fewer exist.
    """
    taken = set(existing_edges)

    # Every ordered pair of distinct nodes, i.e. no self-loops.
    every_pair = []
    for src in range(num_nodes):
        for dst in range(num_nodes):
            if src != dst:
                every_pair.append((src, dst))

    free_pairs = list(set(every_pair).difference(taken))
    how_many = min(num_samples, len(free_pairs))
    return random.sample(free_pairs, how_many)

def compute_accuracy(scores, labels, threshold=0.5):
    """Fraction of predictions that match `labels`.

    Scores are treated as logits: they are passed through a sigmoid and
    compared against `threshold` to obtain 0/1 predictions.

    Args:
        scores: Tensor of raw scores (logits).
        labels: Tensor of 0/1 targets, comparable element-wise with `scores`.
        threshold: Probability above which a prediction counts as positive.

    Returns:
        A Python float in [0, 1]. Returns 0.0 when `labels` is empty instead
        of raising ZeroDivisionError.
    """
    total = labels.numel()
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    preds = (torch.sigmoid(scores) > threshold).float()
    correct = (preds == labels).sum().item()
    return correct / total


In [4]:
def prepare_supervised_data(G_prime_list, G_double_prime_LOL, total_nodes):
    """Build (graph, positive_edges, negative_edges) training triples.

    For every graph in `G_prime_list` and each of its partially observed
    versions in `G_double_prime_LOL`, the positive edges are the edges of the
    full graph that are absent from the observed one. An equal number of
    negative edges (or fewer, if not enough exist) is sampled from node pairs
    that are not edges of the full graph.

    Args:
        G_prime_list: Sequence of graph objects exposing `edge_index`, a
            two-row tensor of (src, dst) node ids.
        G_double_prime_LOL: List of lists; entry ``i`` holds the observed
            variants derived from ``G_prime_list[i]``.
        total_nodes: Size of the node id space used when sampling negatives.

    Returns:
        A list of ``(G_double_prime, positive_edges, negative_edges)`` tuples,
        where both edge collections are lists of (src, dst) tuples.
    """
    data = []
    for idx, G_prime in enumerate(G_prime_list):
        G_double_primes = G_double_prime_LOL[idx]
        true_edges = [tuple(edge) for edge in G_prime.edge_index.t().tolist()]

        for G_double_prime in G_double_primes:
            # A set makes each membership test O(1); the original list lookup
            # made this step quadratic in the number of edges.
            observed_edges = {
                tuple(edge) for edge in G_double_prime.edge_index.t().tolist()
            }
            positive_edges = [e for e in true_edges if e not in observed_edges]
            negative_edges = sample_non_edges(total_nodes, true_edges, len(positive_edges))

            data.append((G_double_prime, positive_edges, negative_edges))
    return data



In [5]:
def _edge_batch(pos_edges, neg_edges, device):
    """Build the two-row candidate-edge tensor and matching 0/1 labels for one sample."""
    pairs = list(pos_edges) + list(neg_edges)
    # reshape(-1, 2) keeps the tensor two-dimensional for any number of pairs.
    candidate_edges = (
        torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous().to(device)
    )
    labels = torch.tensor(
        [1] * len(pos_edges) + [0] * len(neg_edges), dtype=torch.float
    ).to(device)
    return candidate_edges, labels


def train_model(model, train_data, test_data, total_nodes, epochs=50, lr=0.01, device='cpu'):
    """Train `model` on edge-completion samples and print per-epoch metrics.

    Each sample is a ``(graph, pos_edges, neg_edges)`` triple. The graph's
    `edge_index` is the observed connectivity, node features are a
    `total_nodes` x `total_nodes` identity matrix, and the model is asked to
    score `pos_edges` (label 1) and `neg_edges` (label 0) with a
    binary-cross-entropy-with-logits loss.

    Samples with no candidate edges are skipped, since there is nothing to
    score and they would otherwise break tensor construction.

    Args:
        model: Module called as ``model(x, edge_index, candidate_edges)``.
        train_data: Iterable of sample triples used for optimisation.
        test_data: Iterable of sample triples used for evaluation each epoch.
        total_nodes: Size of the identity feature matrix.
        epochs: Number of passes over `train_data`.
        lr: Learning rate for the Adam optimizer.
        device: Device on which the model and tensors are placed.

    Returns:
        None. The model is updated in place and one line is printed per epoch
        with the summed train loss, summed validation loss and validation
        accuracy (``nan`` when no validation edges were scored).
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Identity node features never change, so build them once.
    x = torch.eye(total_nodes).to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for G_double_prime, pos_edges, neg_edges in train_data:
            if len(pos_edges) + len(neg_edges) == 0:
                continue  # nothing to score for this sample
            edge_index = G_double_prime.edge_index.to(device)
            candidate_edges, labels = _edge_batch(pos_edges, neg_edges, device)

            optimizer.zero_grad()
            # reshape(-1) guards against a model that returns a 0-dim score
            # for a single candidate edge, which would not match `labels`.
            scores = model(x, edge_index, candidate_edges).reshape(-1)
            loss = F.binary_cross_entropy_with_logits(scores, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluation
        model.eval()
        with torch.no_grad():
            total_val_loss = 0
            total_val_acc = 0
            total_samples = 0

            for G_double_prime, pos_edges, neg_edges in test_data:
                if len(pos_edges) + len(neg_edges) == 0:
                    continue  # nothing to score for this sample
                edge_index = G_double_prime.edge_index.to(device)
                candidate_edges, labels = _edge_batch(pos_edges, neg_edges, device)

                scores = model(x, edge_index, candidate_edges).reshape(-1)
                val_loss = F.binary_cross_entropy_with_logits(scores, labels)
                total_val_loss += val_loss.item()
                # Weight each sample's accuracy by its edge count so the final
                # figure is an accuracy over all validation edges.
                total_val_acc += compute_accuracy(scores, labels) * len(labels)
                total_samples += len(labels)

            # Avoid ZeroDivisionError when no validation edges were scored.
            val_acc = total_val_acc / total_samples if total_samples else float("nan")
            print(f"[Epoch {epoch+1}] Train Loss: {total_loss:.4f} | Val Loss: {total_val_loss:.4f} | Val Acc: {val_acc:.4f}")



In [6]:

def run_pipeline(G, G_prime_list, G_double_prime_LOL):
    """Prepare data, split it 80/20, then build and train a GraphCompletionModel.

    Args:
        G: Accepted for interface compatibility; not read by this function.
        G_prime_list: Full graphs forwarded to prepare_supervised_data.
        G_double_prime_LOL: Per-graph lists of observed variants, forwarded to
            prepare_supervised_data.

    Returns:
        The trained GraphCompletionModel.
    """
    samples = prepare_supervised_data(G_prime_list, G_double_prime_LOL, TOTAL_NODES)
    # Fixed random_state keeps the train/test split the same across runs.
    train_set, test_set = train_test_split(samples, test_size=0.2, random_state=42)

    completion_model = GraphCompletionModel(
        in_channels=TOTAL_NODES,
        hidden_channels=HIDDEN_DIM,
    )
    train_model(
        completion_model,
        train_set,
        test_set,
        TOTAL_NODES,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
    )
    return completion_model


In [7]:
import pickle as pkl

# Load the main graph object into the notebook-level name `G`.
# NOTE(review): the path is relative to the kernel's working directory; the
# recorded run of this cell raised FileNotFoundError, so the file must be placed
# next to the notebook (or the path adjusted) before this cell can succeed.
# SECURITY: unpickling can execute arbitrary code; only load files from a trusted source.
with open("Main_graph_withNodeFeats.pkl", "rb") as f:
    G = pkl.load(f)

FileNotFoundError: [Errno 2] No such file or directory: 'Main_graph_withNodeFeats.pkl'