In [61]:
import numpy as np
import torch
from torch_geometric.data import Data
import networkx as nx
from pyvis.network import Network

In [62]:
import random
import copy
from torch_geometric.nn import GCNConv, GATConv
import torch.nn as nn

In [63]:
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch_geometric.utils import to_networkx


In [64]:
# Load the raw incidence matrix (whitespace-delimited text) from the working directory.
inc_matrix_aug = np.loadtxt(fname="Aug_inc_matrix")

In [65]:
# Force a 2-D layout with 50 incidence columns (one column per stream/edge).
inc_matrix_aug = np.reshape(inc_matrix_aug, (-1, 50))

In [66]:
# Sanity check of the (num_nodes, num_streams) layout.
np.shape(inc_matrix_aug)

(45, 50)

In [67]:
num_nodes, num_edges = inc_matrix_aug.shape

# --- Step 2: Convert the incidence matrix to a PyG edge_index (multi-edges allowed) ---
# Column j describes one directed edge: the -1 entry is the source and the +1 entry
# is the destination. Columns without exactly one -1 and one +1 are skipped.
edge_list = []
skipped_columns = []
for j in range(num_edges):
    col = inc_matrix_aug[:, j]
    src = np.where(col == -1)[0]
    dst = np.where(col == 1)[0]
    if len(src) == 1 and len(dst) == 1:
        edge_list.append((int(src[0]), int(dst[0])))  # directed edge
    else:
        skipped_columns.append(j)

if skipped_columns:
    print(f"Skipped {len(skipped_columns)} incidence column(s) without exactly one -1 and one +1: {skipped_columns}")

# reshape(-1, 2) keeps the [2, 0] shape PyG expects even when no edge was found.
edge_index = torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()  # shape [2, num_valid_edges]
# One-hot identity features sized from the matrix rather than a hard-coded node count.
x = torch.eye(num_nodes, dtype=torch.float)

# --- Step 3: Create PyG Data object ---
data_inp = Data(x=x, edge_index=edge_index)


In [68]:
# Number of incidence columns converted into directed edges (columns without exactly
# one -1 and one +1 entry are dropped). NOTE: `edge_list` is reassigned by a later
# cell, so re-running this after that cell shows a different value.
len(edge_list)

50

In [69]:
def generate_connected_subgraphs(G, k, n, seed=None, max_attempts=None):
    """Sample up to ``n`` connected subgraphs of ``G`` by deleting ``k`` random nodes.

    Each attempt removes ``k`` distinct nodes chosen uniformly at random from a copy
    of ``G``. The copy is kept only if it is still connected: weakly connected for
    directed graphs, connected for undirected graphs.

    Args:
        G: networkx graph (directed or undirected; multigraphs allowed).
        k: number of nodes to remove per subgraph.
        n: number of subgraphs wanted.
        seed: if given, re-seeds Python's global ``random`` module. This mutates
            global state shared with every other ``random`` caller.
        max_attempts: cap on sampling attempts; ``None`` means ``100 * n``.

    Returns:
        List of graph copies. It may hold fewer than ``n`` entries if the attempt
        budget runs out, and it may contain duplicates (``k == 0`` always yields
        unchanged copies of ``G``).

    Raises:
        ValueError: if ``k`` is negative or not smaller than the number of nodes.
    """
    if seed is not None:
        random.seed(seed)

    if k < 0:
        raise ValueError("k must be non-negative.")
    if G.number_of_nodes() <= k:
        raise ValueError("Cannot remove more nodes than exist in the graph.")

    if max_attempts is None:
        max_attempts = 100 * n  # safety to avoid infinite loops

    # nx.is_weakly_connected raises NetworkXNotImplemented for undirected graphs,
    # so pick the connectivity test that matches the graph type.
    is_connected = nx.is_weakly_connected if G.is_directed() else nx.is_connected

    nodes = list(G.nodes())  # loop-invariant population for random.sample
    subgraphs = []
    attempts = 0
    while len(subgraphs) < n and attempts < max_attempts:
        attempts += 1
        nodes_to_remove = random.sample(nodes, k)
        G_sub = G.copy()
        G_sub.remove_nodes_from(nodes_to_remove)

        if is_connected(G_sub):
            subgraphs.append(G_sub)

    return subgraphs

In [70]:
def pyg_data_to_nx_multigraph(data):
    """Convert a PyG ``Data`` object into a ``networkx.MultiDiGraph``.

    Every node index ``0 .. data.num_nodes - 1`` becomes a node that carries its
    feature row as the ``x`` attribute (a plain Python list). Every column of
    ``data.edge_index`` becomes one directed edge, so parallel edges are preserved.
    """
    node_items = ((idx, {"x": data.x[idx].tolist()}) for idx in range(data.num_nodes))
    edge_pairs = data.edge_index.t().tolist()

    graph = nx.MultiDiGraph()
    graph.add_nodes_from(node_items)
    graph.add_edges_from(edge_pairs)
    return graph
G = pyg_data_to_nx_multigraph(data=data_inp)

In [71]:
graph_data_obj_ls = []
subgraph_ls = []
# Remove k = 0..4 nodes, asking for 10 connected subgraphs each (same seed for every k).
# NOTE(review): k == 0 removes nothing, so it contributes 10 identical copies of G.
for num_removed in range(5):
    subgraph_ls.extend(generate_connected_subgraphs(G, num_removed, n=10, seed=123))

for nx_graph in subgraph_ls:
    # One (u, v) pair per edge key, so parallel edges are preserved.
    edge_list = [(u, v) for u, v, _key in nx_graph.edges(keys=True)]
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

    # Identity features indexed by the ORIGINAL node ids: the matrix is
    # (max_id + 1) x (max_id + 1) and rows of removed nodes stay all-zero.
    present_nodes = list(nx_graph.nodes())
    feature_size = max(present_nodes) + 1
    x_subset = torch.zeros(feature_size, feature_size)
    present_idx = torch.tensor(present_nodes, dtype=torch.long)
    x_subset[present_idx, present_idx] = 1.0

    graph_data_obj_ls.append(Data(x=x_subset, edge_index=edge_index))



In [72]:
subgraph_data_obj_ls = []

for data in graph_data_obj_ls:
    # Rank of the oriented incidence matrix; mathematically this equals
    # #nodes - #connected components (isolated placeholder nodes count as components).
    G_nx = to_networkx(data, to_undirected=False)
    incidence_matrix = nx.incidence_matrix(G_nx, oriented=True).toarray()
    rank = np.linalg.matrix_rank(incidence_matrix)
    num_edges = data.edge_index.size(1)
    edge_indices = list(range(num_edges))  # population for random.sample (loop-invariant)
    masked_graphs_per_data = []  # inner list for each data graph

    # Remove between `rank` and `rank + 5` edges (inclusive), capped at num_edges - 1
    # so at least one edge always survives; the range is empty when rank >= num_edges.
    for edges_to_remove in range(rank, min(rank + 6, num_edges)):
        for _ in range(15):  # generate 15 graphs per mask level
            to_remove = random.sample(edge_indices, edges_to_remove)

            mask = torch.ones(num_edges, dtype=torch.bool)
            mask[to_remove] = False

            data_copy = copy.deepcopy(data)
            data_copy.edge_index = data.edge_index[:, mask]
            if getattr(data, 'edge_attr', None) is not None:
                data_copy.edge_attr = data.edge_attr[mask]

            masked_graphs_per_data.append(data_copy)

    subgraph_data_obj_ls.append(masked_graphs_per_data)


In [73]:

# -------------------------------
# CONFIG
# -------------------------------
TOTAL_NODES = 45       # size of the node space (from G)
HIDDEN_DIM1 = 64       # hidden width passed to GraphCompletionModel
HIDDEN_DIM2 = 128      # NOTE(review): not referenced anywhere in this notebook
EPOCHS = 20
LEARNING_RATE = 1e-2


class GCNEncoder(nn.Module):
    """Two-layer message-passing encoder producing one embedding per node.

    Despite the class name, the default layers are ``GATConv`` (graph attention),
    exactly as before. Pass ``conv_cls=GCNConv``, or any layer with an
    ``(in_channels, out_channels)`` constructor and an ``(x, edge_index)`` forward,
    to swap the operator.

    Args:
        in_channels: input feature size per node.
        hidden_channels: output size of both layers.
        conv_cls: layer class to use; ``None`` means ``GATConv``.
    """

    def __init__(self, in_channels, hidden_channels, conv_cls=None):
        super().__init__()
        if conv_cls is None:
            conv_cls = GATConv  # resolved lazily so the default matches the original model
        self.conv1 = conv_cls(in_channels, hidden_channels)
        self.conv2 = conv_cls(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        # ReLU between the two layers; the second layer's output is returned as-is.
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

class EdgeDecoder(nn.Module):
    """MLP that scores candidate directed edges from node embeddings.

    For each candidate ``(src, dst)`` the embeddings ``z[src]`` and ``z[dst]`` are
    concatenated and passed through ``Linear -> ReLU -> Linear`` to a single logit.

    Args:
        in_channels: embedding size of ``z``.
        hidden_channels: width of the hidden layer (default 64, as before).
    """

    def __init__(self, in_channels, hidden_channels=64):
        super().__init__()
        self.linear1 = nn.Linear(in_channels * 2, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, 1)

    def forward(self, z, edge_index):
        """Return one raw logit per column of ``edge_index`` (shape ``[num_candidates]``)."""
        src, dst = edge_index
        edge_feats = torch.cat([z[src], z[dst]], dim=1)
        edge_feats = F.relu(self.linear1(edge_feats))
        # squeeze(-1), not squeeze(): with a single candidate, squeeze() would return a
        # 0-d tensor, which breaks BCE against [1]-shaped labels and per-element iteration.
        return self.linear2(edge_feats).squeeze(-1)

class GraphCompletionModel(nn.Module):
    """Link predictor: embed the nodes of the observed graph, then score candidate edges.

    ``forward`` returns raw logits (one per candidate edge); apply ``torch.sigmoid``
    to turn them into probabilities.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.encoder = GCNEncoder(in_channels, hidden_channels)
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x, edge_index, candidate_edges):
        node_embeddings = self.encoder(x, edge_index)
        return self.decoder(node_embeddings, candidate_edges)


In [74]:
def sample_non_edges(num_nodes, existing_edges, num_samples):
    """Randomly pick up to ``num_samples`` directed, non-self-loop pairs not in ``existing_edges``.

    Draws from Python's global ``random`` module and returns fewer pairs when fewer exist.
    """
    every_pair = [
        (src, dst)
        for src in range(num_nodes)
        for dst in range(num_nodes)
        if src != dst
    ]
    forbidden = set(existing_edges)
    pool = list(set(every_pair) - forbidden)
    return random.sample(pool, min(num_samples, len(pool)))

def compute_accuracy(scores, labels, threshold=0.5):
    """Fraction of logits whose sigmoid-thresholded prediction matches ``labels``.

    Args:
        scores: raw logits with one entry per label.
        labels: 0/1 float targets.
        threshold: probability cut-off for a positive prediction.

    Returns:
        Accuracy in [0, 1]. An empty ``labels`` tensor returns 0.0 instead of raising
        ZeroDivisionError.
    """
    num_labels = labels.numel()
    if num_labels == 0:
        return 0.0
    preds = (torch.sigmoid(scores) > threshold).float()
    correct = (preds == labels).sum().item()
    return correct / num_labels

In [75]:
def prepare_supervised_data(G_prime_list, G_double_prime_LOL, total_nodes):
    """Build ``(observed_graph, positive_edges, negative_edges)`` training triples.

    For every reference graph ``G_prime_list[i]`` and each masked variant in
    ``G_double_prime_LOL[i]``:
      * positives: edges of G' (as ``(src, dst)`` tuples, duplicates kept) whose pair
        no longer appears anywhere in the masked graph G'';
      * negatives: as many random non-edges of G' over ``total_nodes`` nodes as there
        are positives, drawn with ``sample_non_edges``.

    NOTE: with parallel edges, a masked copy of (u, v) is not a positive if another
    copy of (u, v) survived in G''.

    Returns:
        List of ``(G_double_prime, positive_edges, negative_edges)`` tuples.
    """
    data = []
    for i in range(len(G_prime_list)):
        G_prime = G_prime_list[i]
        true_edges = list(map(tuple, G_prime.edge_index.t().tolist()))

        for G_double_prime in G_double_prime_LOL[i]:
            # A set gives O(1) membership; scanning a list per edge was O(|E'| * |E''|).
            observed_edges = set(map(tuple, G_double_prime.edge_index.t().tolist()))
            positive_edges = [e for e in true_edges if e not in observed_edges]
            negative_edges = sample_non_edges(total_nodes, true_edges, len(positive_edges))

            data.append((G_double_prime, positive_edges, negative_edges))
    return data


In [76]:
def _candidate_tensors(pos_edges, neg_edges, device):
    """Stack positive+negative (src, dst) pairs into a [2, K] index tensor plus matching 1/0 labels."""
    candidate_edges = torch.tensor(pos_edges + neg_edges, dtype=torch.long).reshape(-1, 2).t().contiguous().to(device)
    labels = torch.tensor([1] * len(pos_edges) + [0] * len(neg_edges), dtype=torch.float).to(device)
    return candidate_edges, labels


def train_model(model, train_data, test_data, total_nodes, epochs=20, lr=0.01, device=None):
    """Train ``model`` with BCE on (G'', positives, negatives) triples, printing validation stats per epoch.

    Args:
        model: module called as ``model(x, edge_index, candidate_edges)`` that returns one
            logit per candidate edge.
        train_data, test_data: lists of ``(G_double_prime, pos_edges, neg_edges)``;
            ``G_double_prime.edge_index`` is the observed graph for message passing.
        total_nodes: size of the identity feature matrix fed to the model.
        epochs, lr: Adam optimisation settings.
        device: torch device; ``None`` picks CUDA when available, otherwise CPU.

    Returns:
        Sigmoid probabilities for the last validation sample of the final epoch, or
        ``None`` if no validation sample was scored. Printed losses are sums over samples.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Identity node features are the same for every sample, so build them once.
    x = torch.eye(total_nodes, device=device)
    probs = None

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for G_double_prime, pos_edges, neg_edges in train_data:
            if not pos_edges and not neg_edges:
                continue  # nothing to score; BCE on an empty batch would be NaN
            edge_index = G_double_prime.edge_index.to(device)
            candidate_edges, labels = _candidate_tensors(pos_edges, neg_edges, device)

            optimizer.zero_grad()
            # reshape(-1) guarantees a 1-D score vector even for a single candidate.
            scores = model(x, edge_index, candidate_edges).reshape(-1)
            loss = F.binary_cross_entropy_with_logits(scores, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluation
        model.eval()
        total_val_loss = 0.0
        total_val_acc = 0.0
        total_samples = 0
        with torch.no_grad():
            for G_double_prime, pos_edges, neg_edges in test_data:
                if not pos_edges and not neg_edges:
                    continue
                edge_index = G_double_prime.edge_index.to(device)
                candidate_edges, labels = _candidate_tensors(pos_edges, neg_edges, device)

                scores = model(x, edge_index, candidate_edges).reshape(-1)
                probs = torch.sigmoid(scores)
                total_val_loss += F.binary_cross_entropy_with_logits(scores, labels).item()
                total_val_acc += compute_accuracy(scores, labels) * len(labels)
                total_samples += len(labels)

        val_acc = f"{total_val_acc / total_samples:.4f}" if total_samples else "n/a"
        print(f"[Epoch {epoch+1}] Train Loss: {total_loss:.4f} | Val Loss: {total_val_loss:.4f} | Val Acc: {val_acc}")

    return probs


In [77]:
def run_pipeline(G, G_prime_list, G_double_prime_LOL):
    """Build supervised triples, split them 80/20 (random_state=42), and train a fresh model.

    ``G`` is accepted but not used. Returns ``(model, probs)``, where ``probs`` is
    whatever ``train_model`` returns.

    NOTE(review): the split is over masked variants, so variants of the same G' can
    land in both the train and test sets.
    """
    samples = prepare_supervised_data(G_prime_list, G_double_prime_LOL, TOTAL_NODES)
    train_set, test_set = train_test_split(samples, test_size=0.2, random_state=42)

    model = GraphCompletionModel(in_channels=TOTAL_NODES, hidden_channels=HIDDEN_DIM1)
    return model, train_model(
        model,
        train_set,
        test_set,
        TOTAL_NODES,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
    )

In [78]:
# Train on every (G', G'') pair built above.
model, probs = run_pipeline(G, graph_data_obj_ls, subgraph_data_obj_ls)

[Epoch 1] Train Loss: 193.0931 | Val Loss: 12.5008 | Val Acc: 0.9963
[Epoch 2] Train Loss: 75.8368 | Val Loss: 12.8246 | Val Acc: 0.9961
[Epoch 3] Train Loss: 85.5259 | Val Loss: 21.6426 | Val Acc: 0.9952
[Epoch 4] Train Loss: 104.5439 | Val Loss: 20.5525 | Val Acc: 0.9947
[Epoch 5] Train Loss: 96.6912 | Val Loss: 16.5190 | Val Acc: 0.9958
[Epoch 6] Train Loss: 97.6251 | Val Loss: 23.1837 | Val Acc: 0.9925
[Epoch 7] Train Loss: 91.2201 | Val Loss: 24.0647 | Val Acc: 0.9935
[Epoch 8] Train Loss: 97.6252 | Val Loss: 15.9194 | Val Acc: 0.9958
[Epoch 9] Train Loss: 99.0153 | Val Loss: 21.4104 | Val Acc: 0.9951
[Epoch 10] Train Loss: 89.5823 | Val Loss: 53.9266 | Val Acc: 0.9848
[Epoch 11] Train Loss: 101.5845 | Val Loss: 19.2023 | Val Acc: 0.9953
[Epoch 12] Train Loss: 117.5365 | Val Loss: 51.7023 | Val Acc: 0.9874
[Epoch 13] Train Loss: 147.8240 | Val Loss: 33.6764 | Val Acc: 0.9910
[Epoch 14] Train Loss: 140.1746 | Val Loss: 27.1940 | Val Acc: 0.9926
[Epoch 15] Train Loss: 127.0778 | Val

In [79]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

def evaluate_graph_completion(model, G_prime_list, G_double_prime_LOL, total_nodes=45, threshold=0.5,
                              device=None, num_negatives=0):
    """Score each masked graph G'' against the edges of its reference graph G'.

    Positive candidates are the distinct ``(src, dst)`` pairs present in G' but absent
    from G'' (label 1). With the default ``num_negatives=0`` they are the ONLY
    candidates, so precision is trivially 1 (or 0) and accuracy equals recall: false
    positives are invisible. Pass ``num_negatives > 0`` to also score that many random
    non-edges of G' (label 0, drawn with ``sample_non_edges``).

    NOTE(review): the notebook calls this on all graphs, including those used for
    training, so the numbers are not a held-out estimate.

    Args:
        model: trained model returning one logit per candidate edge.
        G_prime_list: reference graphs (objects with ``edge_index``).
        G_double_prime_LOL: ``G_double_prime_LOL[i]`` lists the masked variants of ``G_prime_list[i]``.
        total_nodes: size of the identity feature matrix.
        threshold: probability cut-off for predicting an edge.
        device: torch device; ``None`` picks CUDA when available, otherwise CPU.
        num_negatives: negative candidates per masked graph (0 keeps the original behaviour).

    Returns:
        List of per-(G', G'') metric dicts.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    x = torch.eye(total_nodes, device=device)  # identical for every graph
    all_results = []

    for g_idx, G_prime in enumerate(G_prime_list):
        edges_G_prime = set(map(tuple, G_prime.edge_index.t().tolist()))

        for G_double_prime in G_double_prime_LOL[g_idx]:
            edges_G_double_prime = set(map(tuple, G_double_prime.edge_index.t().tolist()))
            true_missing_edges = list(edges_G_prime - edges_G_double_prime)

            if len(true_missing_edges) == 0:
                continue  # skip if no missing edges

            negative_edges = []
            if num_negatives > 0:
                negative_edges = sample_non_edges(total_nodes, list(edges_G_prime), num_negatives)

            candidate_edges = true_missing_edges + negative_edges
            y_true = [1] * len(true_missing_edges) + [0] * len(negative_edges)

            candidate_edges_tensor = torch.tensor(candidate_edges, dtype=torch.long).t().contiguous().to(device)
            edge_index = G_double_prime.edge_index.to(device)
            with torch.no_grad():
                # reshape(-1) keeps a 1-D tensor even when there is a single candidate edge.
                probs = torch.sigmoid(model(x, edge_index, candidate_edges_tensor)).reshape(-1).cpu()

            y_pred = [1 if p > threshold else 0 for p in probs.tolist()]
            predicted_edges = [e for e, pred in zip(candidate_edges, y_pred) if pred]
            num_correct = sum(1 for t, pred in zip(y_true, y_pred) if t == 1 and pred == 1)

            all_results.append({
                'G_index': g_idx,
                'correct_predictions': num_correct,  # true positives among predicted edges
                'precision': precision_score(y_true, y_pred, zero_division=0),
                'recall': recall_score(y_true, y_pred, zero_division=0),
                'f1_score': f1_score(y_true, y_pred, zero_division=0),
                'accuracy': accuracy_score(y_true, y_pred),
                'num_predicted': len(predicted_edges),
                'num_true_missing': len(true_missing_edges),
            })

    return all_results


all_results = evaluate_graph_completion(model, graph_data_obj_ls, subgraph_data_obj_ls)

In [84]:
# One summary line per evaluated (G', G'') pair.
per_graph_counts = (
    (r['correct_predictions'], r['num_true_missing'], r['num_predicted'])
    for r in all_results
)
for correct_pred, num_masked, num_predicted in per_graph_counts:
    print(f'correct_pred = {correct_pred}, num_masked = {num_masked}, num_predicted = {num_predicted}')

correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 43, num_masked = 44, num_predicted = 43
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 44, num_masked = 44, num_predicted = 44
correct_pred = 45, num_masked = 45, num_predicted = 45
correct_pred = 45, num_masked = 45, num_predicted = 45
correct_pred = 45, num_masked = 45, num_predicted = 45
correct_pr

In [81]:
def visualize_incidence_with_pyvis(inc_matrix, output_file="incidence_multigraph2.html",
                                   num_primary_nodes=26, tail_value=-1):
    """Render an oriented incidence matrix as an interactive Pyvis HTML page.

    Each column is one directed stream. The entry equal to ``tail_value`` marks the
    source node and the entry equal to ``-tail_value`` marks the destination. The
    default ``tail_value=-1`` matches how this notebook builds ``edge_index`` from
    the same matrix and networkx's ``incidence_matrix(oriented=True)`` convention.
    Reading +1 as the source would draw every edge reversed. A ``MultiDiGraph`` keeps
    parallel streams visible instead of collapsing them.

    Args:
        inc_matrix: 2-D array of shape (num_nodes, num_streams).
        output_file: path of the HTML file to write.
        num_primary_nodes: nodes with index below this get ``Node_i`` labels, the rest ``Extra_i``.
        tail_value: incidence value identifying a stream's source node (-1 or +1).
    """
    num_nodes, num_streams = inc_matrix.shape
    G = nx.MultiDiGraph()

    # 1. Add nodes with labels (plain Python ints).
    for i in range(num_nodes):
        label = f"Node_{i}" if i < num_primary_nodes else f"Extra_{i}"
        G.add_node(i, label=label)

    # 2. Add one edge per valid incidence column.
    for j in range(num_streams):
        col = inc_matrix[:, j]
        src_indices = np.where(col == tail_value)[0]
        dst_indices = np.where(col == -tail_value)[0]

        if len(src_indices) != 1 or len(dst_indices) != 1:
            print(f"Skipping stream {j}: must have exactly 1 source and 1 sink.")
            continue

        src = int(src_indices[0])
        dst = int(dst_indices[0])
        G.add_edge(src, dst, label=f"{src}-{dst}")

    # 3. Create the Pyvis graph and write it to disk.
    net = Network(height='700px', width='100%', directed=True, notebook=False)
    net.from_nx(G)
    net.write_html(output_file)

visualize_incidence_with_pyvis(inc_matrix_aug)

In [82]:
def visualize_node_removal(G, G_prime, output_file="node_removal2.html"):
    """Draw the full graph ``G``: nodes kept in ``G_prime`` in light green, removed ones in red.

    A node counts as kept when it is an endpoint in ``G_prime.edge_index``, so an
    isolated node of G' would also be drawn red. Edges are those of ``G``; parallel
    edges collapse because a ``DiGraph`` is used.
    """
    kept_nodes = set(G_prime.edge_index.flatten().tolist())

    G_nx = nx.DiGraph()
    for node in set(range(G.num_nodes)):
        node_color = "lightgreen" if node in kept_nodes else "red"
        G_nx.add_node(node, label=f"Node_{node}", color=node_color)
    G_nx.add_edges_from(map(tuple, G.edge_index.t().tolist()))

    net = Network(height="700px", width="100%", directed=True)
    net.from_nx(G_nx)
    net.write_html(output_file)

visualize_node_removal(data_inp, graph_data_obj_ls[12])

In [83]:
def visualize_edge_masking(G_prime, G_double_prime, positive_edges, negative_edges, output_file="edge_masking.html"):
    """Overlay observed, masked and negative edges on the nodes of ``G_prime``; save as Pyvis HTML.

    * black, width 2 : edges kept in ``G_double_prime`` (observed)
    * gray, dashed   : ``positive_edges`` not already observed (masked edges)
    * red, width 2   : ``negative_edges`` (sampled non-edges)

    Nodes are ``range(G_prime.num_nodes)``. Edge endpoints outside that range are
    created implicitly by networkx, without a label or colour.
    """
    observed = set(map(tuple, G_double_prime.edge_index.t().tolist()))

    G_nx = nx.DiGraph()
    for node in set(range(G_prime.num_nodes)):
        G_nx.add_node(node, label=f"Node_{node}", color="lightblue")

    for u, v in observed:
        G_nx.add_edge(u, v, color="black", width=2)

    masked = ((u, v) for u, v in positive_edges if (u, v) not in observed)
    for u, v in masked:
        G_nx.add_edge(u, v, color="gray", dashes=True, title="Masked Edge")

    for u, v in negative_edges:
        G_nx.add_edge(u, v, color="red", width=2, title="Negative Edge")

    net = Network(height="700px", width="100%", directed=True)
    net.from_nx(G_nx)
    net.write_html(output_file)

supervised_data = prepare_supervised_data(graph_data_obj_ls, subgraph_data_obj_ls, total_nodes=data_inp.num_nodes)
G_double_prime, pos_edges, neg_edges = supervised_data[0]
visualize_edge_masking(graph_data_obj_ls[0], G_double_prime, pos_edges, neg_edges, "Gprime_vs_Gdoubleprime_0.html")