In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.sparse import csgraph
from scipy.sparse.linalg import eigsh

In [3]:
from torch_geometric.data import Data
import networkx as nx
import random
import copy
from torch_geometric.utils import to_networkx

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
from torch.utils.data import Dataset, DataLoader

class MaskedGraphDataset(Dataset):
    def __init__(self, graph_data_obj_ls, subgraph_data_obj_ls):
        self.inputs = []
        self.targets = []

        for full_graph, masked_versions in zip(graph_data_obj_ls, subgraph_data_obj_ls):
            for masked_graph in masked_versions:
                self.inputs.append(masked_graph)  # G''
                self.targets.append(full_graph)   # G'

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_graph = self.inputs[idx]
        target_graph = self.targets[idx]

        # Ensure that all tensors are returned
        return {
            'input': input_graph,
            'target': target_graph
        }


In [22]:
inc_matrix_aug = np.loadtxt("Aug_inc_matrix")
inc_matrix_aug = inc_matrix_aug.reshape(-1,50)

num_nodes, num_edges = inc_matrix_aug.shape

# --- Step 2: Convert to edge_index for PyG (multi-edges allowed) ---
edge_list = []
for j in range(num_edges):
    col = inc_matrix_aug[:, j]
    src = np.where(col == -1)[0]
    dst = np.where(col == 1)[0]
    if len(src) == 1 and len(dst) == 1:
        edge_list.append((src[0], dst[0]))  # directed edge

edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()  # shape [2, num_edges]
x = torch.eye(45, dtype=torch.float)

# --- Step 3: Create PyG Data object ---
data_inp= Data(x=x, edge_index=edge_index)


def pyg_data_to_nx_multigraph(data):
    G = nx.MultiDiGraph()

    # Step 1: Add all nodes with features
    for i in range(data.num_nodes):
        G.add_node(i, x=data.x[i].tolist())  # attach node features

    # Step 2: Add all edges (with support for multiple edges)
    edge_list = data.edge_index.t().tolist()
    G.add_edges_from(edge_list)

    return G
G = pyg_data_to_nx_multigraph(data=data_inp)
def generate_connected_subgraphs(G, k, n, seed=None):
    if seed is not None:
        random.seed(seed)

    if G.number_of_nodes() <= k:
        raise ValueError("Cannot remove more nodes than exist in the graph.")

    subgraphs = []
    attempts = 0
    max_attempts = 100 * n  # safety to avoid infinite loops

    while len(subgraphs) < n and attempts < max_attempts:
        attempts += 1
        nodes_to_remove = random.sample(list(G.nodes()), k)
        G_sub = G.copy()
        G_sub.remove_nodes_from(nodes_to_remove)

        if nx.is_weakly_connected(G_sub):
            subgraphs.append(G_sub)

    return subgraphs
graph_data_obj_ls = []
subgraph_ls = []
for k in range(5):
    subgraphs = generate_connected_subgraphs(G, k, n=10, seed=123)
    subgraph_ls.extend(subgraphs)

for nx_graph in subgraph_ls:
    # Get all edges with duplicates preserved
    edge_list = [(u, v) for u, v, _ in nx_graph.edges(keys=True)]
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

    # Build identity features using original node indices
    all_nodes = list(nx_graph.nodes())
    num_nodes_global = 45
    x = torch.eye(num_nodes_global)  # [45, 45]
    node_mask = torch.zeros_like(x)  # [45, 45]

    for node in all_nodes:
        node_mask[node] = x[node]  # Keep features only for nodes in this subgraph

    x_subset = node_mask

    data = Data(x=x_subset, edge_index=edge_index)
    graph_data_obj_ls.append(data)
subgraph_data_obj_ls = []

for data in graph_data_obj_ls:
    G_nx = to_networkx(data, to_undirected=False)
    incidence_matrix = nx.incidence_matrix(G_nx, oriented=True).toarray()
    rank = np.linalg.matrix_rank(incidence_matrix)
    num_edges = data.edge_index.size(1)
    num_nodes = data.num_nodes

    masked_graphs_per_data = []

    for edges_to_remove in range(rank, min(rank + 6, num_edges)):
        for _ in range(15):
            if num_edges <= edges_to_remove:
                continue

            data_copy = copy.deepcopy(data)

            # -------------------------------
            # 1. Mask edges
            edge_indices = list(range(num_edges))
            to_remove = random.sample(edge_indices, edges_to_remove)
            mask = torch.ones(num_edges, dtype=torch.bool)
            mask[to_remove] = False
            data_copy.edge_index = data.edge_index[:, mask]

            if hasattr(data, 'edge_attr') and data.edge_attr is not None:
                data_copy.edge_attr = data.edge_attr[mask]

            # -------------------------------
            # 2. Mask nodes (retain ~90% randomly)
            node_mask = torch.ones(45, dtype=torch.bool)
            total_nodes = 45
            num_nodes_to_mask = int(0.1 * total_nodes)
            nodes_to_mask = random.sample(range(45), num_nodes_to_mask)
            node_mask[nodes_to_mask] = False

            data_copy.masked_nodes = node_mask  # Add this attribute to use later

            masked_graphs_per_data.append(data_copy)

    subgraph_data_obj_ls.append(masked_graphs_per_data)

In [23]:
from torch.utils.data import random_split

# Create full dataset
full_dataset = MaskedGraphDataset(graph_data_obj_ls, subgraph_data_obj_ls)

# Split: 80% train, 20% validation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

# PyTorch Geometric uses a custom collate_fn
from torch_geometric.loader import DataLoader as PyGDataLoader

train_loader = PyGDataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = PyGDataLoader(val_dataset, batch_size=8, shuffle=False)


In [24]:
import torch.nn as nn
from torch_geometric.nn import TransformerConv
from torch_geometric.utils import to_dense_batch

class GraphTransformer(nn.Module):
    def __init__(self, num_nodes=45, in_dim=45, hidden_dim=128, num_heads=4, num_layers=3):
        super(GraphTransformer, self).__init__()
        self.num_nodes = num_nodes
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim

        # Initial projection (optional: can be identity if x is one-hot)
        self.input_proj = nn.Linear(in_dim, hidden_dim)

        # Stack of TransformerConv layers
        self.transformer_layers = nn.ModuleList([
            TransformerConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, dropout=0.1)
            for _ in range(num_layers)
        ])

        # Output projections
        self.node_predictor = nn.Linear(hidden_dim, in_dim)  # for masked node reconstruction
        self.edge_predictor = nn.Bilinear(hidden_dim, hidden_dim, 1)  # for link prediction

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.input_proj(x)  # shape: [num_nodes, hidden_dim]

        for layer in self.transformer_layers:
            x = layer(x, edge_index).relu()

        # -------- Node Prediction --------
        if hasattr(data, 'masked_nodes'):
            masked_nodes = data.masked_nodes.bool()
            node_logits = self.node_predictor(x)  # shape: [num_nodes, 45]
        else:
            node_logits = None

        # -------- Edge Prediction --------
        # Predict edge logits for all pairs
        edge_logits = torch.matmul(x, x.T)  # [num_nodes, num_nodes]
        edge_logits = torch.sigmoid(edge_logits)

        return {
            'node_logits': node_logits,
            'edge_logits': edge_logits,
            'final_node_embeddings': x
        }


In [25]:
def node_reconstruction_loss(output_logits, target_x, masked_nodes):
    """
    output_logits: [num_nodes, 45]
    target_x:      [num_nodes, 45]
    masked_nodes:  [num_nodes] (bool)
    """
    loss_fn = nn.CrossEntropyLoss()
    # Convert one-hot targets to class indices
    target_classes = target_x.argmax(dim=1)
    # Only use masked node indices
    masked_indices = masked_nodes.nonzero(as_tuple=True)[0]
    return loss_fn(output_logits[masked_indices], target_classes[masked_indices])

def edge_reconstruction_loss(pred_adj, target_edge_index, num_nodes):
    """
    pred_adj: [num_nodes, num_nodes] - predicted link logits
    target_edge_index: [2, num_edges]
    """
    # Create ground-truth adjacency matrix
    true_adj = torch.zeros_like(pred_adj)
    true_adj[target_edge_index[0], target_edge_index[1]] = 1.0

    # Binary Cross-Entropy over all pairs
    loss_fn = nn.BCELoss()
    return loss_fn(pred_adj.view(-1), true_adj.view(-1))


In [29]:
def evaluate_batch_node_accuracy(node_logits, target_x, masked_nodes):
    """
    node_logits: [num_nodes, 45]
    target_x:    [num_nodes, 45]
    masked_nodes: [num_nodes] (bool)
    """
    target_classes = target_x.argmax(dim=1)              # [num_nodes]
    predicted_classes = node_logits.argmax(dim=1)        # [num_nodes]

    masked_indices = masked_nodes.nonzero(as_tuple=True)[0]
    masked_targets = target_classes[masked_indices]
    masked_preds = predicted_classes[masked_indices]

    correct = (masked_targets == masked_preds).sum().item()
    total = masked_targets.size(0)

    return correct, total

In [30]:
from tqdm import tqdm

def train(model, dataloader, optimizer, device):
    model.train()
    
    total_loss = 0
    correct_masked = 0
    total_masked = 0
    
    for batch in tqdm(dataloader):
        data = batch['input'].to(device)
        target = batch['target'].to(device)

        optimizer.zero_grad()
        output = model(data)

        node_loss = node_reconstruction_loss(
            output['node_logits'],
            target.x,
            data.masked_nodes
        )

        edge_loss = edge_reconstruction_loss(
            output['edge_logits'],
            target.edge_index,
            target.num_nodes
        )

        loss = node_loss + edge_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct, total = evaluate_batch_node_accuracy(
        output['node_logits'], target.x, data.masked_nodes
    )
        correct_masked += correct
        total_masked += total

    accuracy = 100.0 * correct_masked / max(1, total_masked)
    print(f"Train Loss = {total_loss:.4f} | Masked Node Accuracy = {accuracy:.2f}%")

    return total_loss / len(dataloader)


In [31]:
@torch.no_grad()
def validate(model, dataloader, device):
    model.eval()
    total_loss = 0

    for batch in dataloader:
        data = batch['input'].to(device)
        target = batch['target'].to(device)

        output = model(data)

        node_loss = node_reconstruction_loss(
            output['node_logits'],
            target.x,
            data.masked_nodes
        )

        edge_loss = edge_reconstruction_loss(
            output['edge_logits'],
            target.edge_index,
            target.num_nodes
        )

        loss = node_loss + edge_loss
        total_loss += loss.item()

    return total_loss / len(dataloader)


In [32]:
device = torch.device("cpu")
model = GraphTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, 51):
    train_loss = train(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

  0%|          | 0/405 [00:00<?, ?it/s]

100%|██████████| 405/405 [00:05<00:00, 68.80it/s]


Train Loss = 1036.9818 | Masked Node Accuracy = 88.54%
Epoch 1: Train Loss = 2.5604, Val Loss = 1.4051


100%|██████████| 405/405 [00:05<00:00, 71.01it/s]


Train Loss = 490.3507 | Masked Node Accuracy = 100.00%
Epoch 2: Train Loss = 1.2107, Val Loss = 1.0868


100%|██████████| 405/405 [00:05<00:00, 71.40it/s]


Train Loss = 414.9390 | Masked Node Accuracy = 100.00%
Epoch 3: Train Loss = 1.0245, Val Loss = 0.9773


100%|██████████| 405/405 [00:05<00:00, 72.00it/s]


Train Loss = 382.6298 | Masked Node Accuracy = 100.00%
Epoch 4: Train Loss = 0.9448, Val Loss = 0.9180


100%|██████████| 405/405 [00:05<00:00, 71.09it/s]


Train Loss = 363.2556 | Masked Node Accuracy = 100.00%
Epoch 5: Train Loss = 0.8969, Val Loss = 0.8800


100%|██████████| 405/405 [00:05<00:00, 70.08it/s]


Train Loss = 350.6107 | Masked Node Accuracy = 100.00%
Epoch 6: Train Loss = 0.8657, Val Loss = 0.8544


100%|██████████| 405/405 [00:05<00:00, 70.22it/s]


Train Loss = 341.6485 | Masked Node Accuracy = 100.00%
Epoch 7: Train Loss = 0.8436, Val Loss = 0.8352


100%|██████████| 405/405 [00:05<00:00, 70.48it/s]


Train Loss = 334.8150 | Masked Node Accuracy = 100.00%
Epoch 8: Train Loss = 0.8267, Val Loss = 0.8204


100%|██████████| 405/405 [00:05<00:00, 69.78it/s]


Train Loss = 329.4289 | Masked Node Accuracy = 100.00%
Epoch 9: Train Loss = 0.8134, Val Loss = 0.8084


100%|██████████| 405/405 [00:05<00:00, 70.17it/s]


Train Loss = 324.9655 | Masked Node Accuracy = 100.00%
Epoch 10: Train Loss = 0.8024, Val Loss = 0.7981


100%|██████████| 405/405 [00:05<00:00, 70.36it/s]


Train Loss = 321.0574 | Masked Node Accuracy = 100.00%
Epoch 11: Train Loss = 0.7927, Val Loss = 0.7893


100%|██████████| 405/405 [00:05<00:00, 70.28it/s]


Train Loss = 317.7541 | Masked Node Accuracy = 100.00%
Epoch 12: Train Loss = 0.7846, Val Loss = 0.7816


100%|██████████| 405/405 [00:05<00:00, 69.89it/s]


Train Loss = 314.8401 | Masked Node Accuracy = 100.00%
Epoch 13: Train Loss = 0.7774, Val Loss = 0.7748


100%|██████████| 405/405 [00:05<00:00, 69.45it/s]


Train Loss = 312.2913 | Masked Node Accuracy = 100.00%
Epoch 14: Train Loss = 0.7711, Val Loss = 0.7693


100%|██████████| 405/405 [00:05<00:00, 69.89it/s]


Train Loss = 310.1908 | Masked Node Accuracy = 100.00%
Epoch 15: Train Loss = 0.7659, Val Loss = 0.7645


100%|██████████| 405/405 [00:06<00:00, 65.51it/s]


Train Loss = 308.4405 | Masked Node Accuracy = 100.00%
Epoch 16: Train Loss = 0.7616, Val Loss = 0.7604


100%|██████████| 405/405 [00:05<00:00, 69.87it/s]


Train Loss = 306.9544 | Masked Node Accuracy = 100.00%
Epoch 17: Train Loss = 0.7579, Val Loss = 0.7569


100%|██████████| 405/405 [00:05<00:00, 70.11it/s]


Train Loss = 305.6083 | Masked Node Accuracy = 100.00%
Epoch 18: Train Loss = 0.7546, Val Loss = 0.7539


100%|██████████| 405/405 [00:05<00:00, 69.60it/s]


Train Loss = 304.4156 | Masked Node Accuracy = 100.00%
Epoch 19: Train Loss = 0.7516, Val Loss = 0.7511


100%|██████████| 405/405 [00:05<00:00, 67.97it/s]


Train Loss = 303.3885 | Masked Node Accuracy = 100.00%
Epoch 20: Train Loss = 0.7491, Val Loss = 0.7488


100%|██████████| 405/405 [00:05<00:00, 68.85it/s]


Train Loss = 302.4587 | Masked Node Accuracy = 100.00%
Epoch 21: Train Loss = 0.7468, Val Loss = 0.7467


100%|██████████| 405/405 [00:05<00:00, 69.48it/s]


Train Loss = 301.6597 | Masked Node Accuracy = 100.00%
Epoch 22: Train Loss = 0.7448, Val Loss = 0.7448


100%|██████████| 405/405 [00:05<00:00, 70.12it/s]


Train Loss = 301.0043 | Masked Node Accuracy = 100.00%
Epoch 23: Train Loss = 0.7432, Val Loss = 0.7432


100%|██████████| 405/405 [00:05<00:00, 70.23it/s]


Train Loss = 300.3483 | Masked Node Accuracy = 100.00%
Epoch 24: Train Loss = 0.7416, Val Loss = 0.7417


100%|██████████| 405/405 [00:05<00:00, 71.92it/s]


Train Loss = 299.6976 | Masked Node Accuracy = 100.00%
Epoch 25: Train Loss = 0.7400, Val Loss = 0.7403


100%|██████████| 405/405 [00:05<00:00, 71.62it/s]


Train Loss = 299.2076 | Masked Node Accuracy = 100.00%
Epoch 26: Train Loss = 0.7388, Val Loss = 0.7391


100%|██████████| 405/405 [00:05<00:00, 71.99it/s]


Train Loss = 298.7729 | Masked Node Accuracy = 100.00%
Epoch 27: Train Loss = 0.7377, Val Loss = 0.7381


100%|██████████| 405/405 [00:05<00:00, 71.52it/s]


Train Loss = 298.4030 | Masked Node Accuracy = 100.00%
Epoch 28: Train Loss = 0.7368, Val Loss = 0.7371


100%|██████████| 405/405 [00:05<00:00, 73.27it/s]


Train Loss = 298.0366 | Masked Node Accuracy = 100.00%
Epoch 29: Train Loss = 0.7359, Val Loss = 0.7362


100%|██████████| 405/405 [00:05<00:00, 71.78it/s]


Train Loss = 297.6552 | Masked Node Accuracy = 100.00%
Epoch 30: Train Loss = 0.7350, Val Loss = 0.7354


100%|██████████| 405/405 [00:05<00:00, 72.87it/s]


Train Loss = 297.3910 | Masked Node Accuracy = 100.00%
Epoch 31: Train Loss = 0.7343, Val Loss = 0.7347


100%|██████████| 405/405 [00:05<00:00, 72.16it/s]


Train Loss = 297.1058 | Masked Node Accuracy = 100.00%
Epoch 32: Train Loss = 0.7336, Val Loss = 0.7340


100%|██████████| 405/405 [00:05<00:00, 71.66it/s]


Train Loss = 296.8399 | Masked Node Accuracy = 100.00%
Epoch 33: Train Loss = 0.7329, Val Loss = 0.7334


100%|██████████| 405/405 [00:05<00:00, 71.82it/s]


Train Loss = 296.6330 | Masked Node Accuracy = 100.00%
Epoch 34: Train Loss = 0.7324, Val Loss = 0.7329


100%|██████████| 405/405 [00:05<00:00, 71.63it/s]


Train Loss = 296.4183 | Masked Node Accuracy = 100.00%
Epoch 35: Train Loss = 0.7319, Val Loss = 0.7324


100%|██████████| 405/405 [00:05<00:00, 71.67it/s]


Train Loss = 296.1892 | Masked Node Accuracy = 100.00%
Epoch 36: Train Loss = 0.7313, Val Loss = 0.7319


100%|██████████| 405/405 [00:05<00:00, 72.04it/s]


Train Loss = 296.0433 | Masked Node Accuracy = 100.00%
Epoch 37: Train Loss = 0.7310, Val Loss = 0.7315


100%|██████████| 405/405 [00:05<00:00, 72.88it/s]


Train Loss = 295.8627 | Masked Node Accuracy = 100.00%
Epoch 38: Train Loss = 0.7305, Val Loss = 0.7311


100%|██████████| 405/405 [00:05<00:00, 72.33it/s]


Train Loss = 295.7302 | Masked Node Accuracy = 100.00%
Epoch 39: Train Loss = 0.7302, Val Loss = 0.7308


100%|██████████| 405/405 [00:05<00:00, 71.40it/s]


Train Loss = 295.5675 | Masked Node Accuracy = 100.00%
Epoch 40: Train Loss = 0.7298, Val Loss = 0.7304


100%|██████████| 405/405 [00:05<00:00, 72.07it/s]


Train Loss = 295.4643 | Masked Node Accuracy = 100.00%
Epoch 41: Train Loss = 0.7295, Val Loss = 0.7301


100%|██████████| 405/405 [00:05<00:00, 70.95it/s]


Train Loss = 295.3853 | Masked Node Accuracy = 100.00%
Epoch 42: Train Loss = 0.7293, Val Loss = 0.7299


100%|██████████| 405/405 [00:05<00:00, 72.87it/s]


Train Loss = 295.1599 | Masked Node Accuracy = 100.00%
Epoch 43: Train Loss = 0.7288, Val Loss = 0.7296


100%|██████████| 405/405 [00:05<00:00, 72.27it/s]


Train Loss = 295.0850 | Masked Node Accuracy = 100.00%
Epoch 44: Train Loss = 0.7286, Val Loss = 0.7294


100%|██████████| 405/405 [00:05<00:00, 71.85it/s]


Train Loss = 295.0150 | Masked Node Accuracy = 100.00%
Epoch 45: Train Loss = 0.7284, Val Loss = 0.7292


100%|██████████| 405/405 [00:05<00:00, 73.37it/s]


Train Loss = 294.9188 | Masked Node Accuracy = 100.00%
Epoch 46: Train Loss = 0.7282, Val Loss = 0.7290


100%|██████████| 405/405 [00:05<00:00, 71.52it/s]


Train Loss = 294.8206 | Masked Node Accuracy = 100.00%
Epoch 47: Train Loss = 0.7280, Val Loss = 0.7288


100%|██████████| 405/405 [00:05<00:00, 71.81it/s]


Train Loss = 294.8225 | Masked Node Accuracy = 100.00%
Epoch 48: Train Loss = 0.7280, Val Loss = 0.7286


100%|██████████| 405/405 [00:05<00:00, 72.19it/s]


Train Loss = 294.7276 | Masked Node Accuracy = 100.00%
Epoch 49: Train Loss = 0.7277, Val Loss = 0.7284


100%|██████████| 405/405 [00:05<00:00, 71.66it/s]


Train Loss = 294.6377 | Masked Node Accuracy = 100.00%
Epoch 50: Train Loss = 0.7275, Val Loss = 0.7283
