In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.sparse import csgraph
from scipy.sparse.linalg import eigsh
from torch_geometric.data import Data
import networkx as nx
import random
import copy
from torch_geometric.utils import to_networkx
from torch.utils.data import Dataset, DataLoader


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class MaskedGraphDataset(Dataset):
    def __init__(self, graph_data_obj_ls, subgraph_data_obj_ls):
        self.inputs = []
        self.targets = []

        for full_graph, masked_versions in zip(graph_data_obj_ls, subgraph_data_obj_ls):
            for masked_graph in masked_versions:
                self.inputs.append(masked_graph)  # G''
                self.targets.append(full_graph)   # G'

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_graph = self.inputs[idx]
        target_graph = self.targets[idx]

        # Ensure that all tensors are returned
        return {
            'input': input_graph,
            'target': target_graph
        }


In [3]:

inc_matrix_aug = np.loadtxt("Aug_inc_matrix")
inc_matrix_aug = inc_matrix_aug.reshape(-1,50)

num_nodes, num_edges = inc_matrix_aug.shape

# --- Step 2: Convert to edge_index for PyG (multi-edges allowed) ---
edge_list = []
for j in range(num_edges):
    col = inc_matrix_aug[:, j]
    src = np.where(col == -1)[0]
    dst = np.where(col == 1)[0]
    if len(src) == 1 and len(dst) == 1:
        edge_list.append((src[0], dst[0]))  # directed edge

edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()  # shape [2, num_edges]
x = torch.eye(45, dtype=torch.float)

# --- Step 3: Create PyG Data object ---
data_inp= Data(x=x, edge_index=edge_index, )


def pyg_data_to_nx_multigraph(data):
    G = nx.MultiDiGraph()

    # Step 1: Add all nodes with features
    for i in range(data.num_nodes):
        G.add_node(i, x=data.x[i].tolist())  # attach node features

    # Step 2: Add all edges (with support for multiple edges)
    edge_list = data.edge_index.t().tolist()
    G.add_edges_from(edge_list)

    return G
G = pyg_data_to_nx_multigraph(data=data_inp)
def generate_connected_subgraphs(G, k, n, seed=None):
    if seed is not None:
        random.seed(seed)

    if G.number_of_nodes() <= k:
        raise ValueError("Cannot remove more nodes than exist in the graph.")

    subgraphs = []
    attempts = 0
    max_attempts = 100 * n  # safety to avoid infinite loops

    while len(subgraphs) < n and attempts < max_attempts:
        attempts += 1
        nodes_to_remove = random.sample(list(G.nodes()), k)
        G_sub = G.copy()
        G_sub.remove_nodes_from(nodes_to_remove)

        if nx.is_weakly_connected(G_sub):
            subgraphs.append(G_sub)

    return subgraphs
graph_data_obj_ls = []
subgraph_ls = []
for k in range(5):
    subgraphs = generate_connected_subgraphs(G, k, n=10, seed=123)
    subgraph_ls.extend(subgraphs)

for nx_graph in subgraph_ls:
    # Get all edges with duplicates preserved
    edge_list = [(u, v) for u, v, _ in nx_graph.edges(keys=True)]
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

    # Build identity features using original node indices
    all_nodes = list(nx_graph.nodes())
    num_nodes_global = 45
    x = torch.eye(num_nodes_global)  # [45, 45]
    node_mask = torch.zeros_like(x)  # [45, 45]

    for node in all_nodes:
        node_mask[node] = x[node]  # Keep features only for nodes in this subgraph

    x_subset = node_mask

    data = Data(x=x_subset, edge_index=edge_index)
    graph_data_obj_ls.append(data)
subgraph_data_obj_ls = []

for data in graph_data_obj_ls:
    G_nx = to_networkx(data, to_undirected=False)
    incidence_matrix = nx.incidence_matrix(G_nx, oriented=True).toarray()
    rank = np.linalg.matrix_rank(incidence_matrix)
    num_edges = data.edge_index.size(1)
    num_nodes = data.num_nodes

    masked_graphs_per_data = []

    for edges_to_remove in range(rank, min(rank + 6, num_edges)):
        for _ in range(15):
            if num_edges <= edges_to_remove:
                continue

            data_copy = copy.deepcopy(data)

            # -------------------------------
            # 1. Mask edges
            edge_indices = list(range(num_edges))
            to_remove = random.sample(edge_indices, edges_to_remove)
            mask = torch.ones(num_edges, dtype=torch.bool)
            mask[to_remove] = False
            data_copy.edge_index = data.edge_index[:, mask]

            if hasattr(data, 'edge_attr') and data.edge_attr is not None:
                data_copy.edge_attr = data.edge_attr[mask]

            # -------------------------------
            # 2. Mask nodes (retain ~90% randomly)
            node_mask = torch.ones(45, dtype=torch.bool)
            total_nodes = 45
            num_nodes_to_mask = int(0.1 * total_nodes)
            nodes_to_mask = random.sample(range(45), num_nodes_to_mask)
            node_mask[nodes_to_mask] = False

            data_copy.masked_nodes = node_mask  # Add this attribute to use later

            masked_graphs_per_data.append(data_copy)

    subgraph_data_obj_ls.append(masked_graphs_per_data)


In [62]:

from torch.utils.data import random_split

# Create full dataset
full_dataset = MaskedGraphDataset(graph_data_obj_ls, subgraph_data_obj_ls)

# Split: 80% train, 20% validation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

# PyTorch Geometric uses a custom collate_fn
from torch_geometric.loader import DataLoader as PyGDataLoader

train_loader = PyGDataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = PyGDataLoader(val_dataset, batch_size=1, shuffle=False)


In [63]:
from torch_geometric.nn import TransformerConv

class GraphTransformer(nn.Module):
    def __init__(self, num_nodes=45, in_dim=45, hidden_dim=128, num_heads=4, num_layers=3):
        super(GraphTransformer, self).__init__()
        self.num_nodes = num_nodes
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim

        self.input_proj = nn.Linear(in_dim, hidden_dim)

        self.transformer_layers = nn.ModuleList([
            TransformerConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, dropout=0.1)
            for _ in range(num_layers)
        ])

        self.node_predictor = nn.Linear(hidden_dim, in_dim)  # Node classification
        self.edge_predictor = nn.Bilinear(hidden_dim, hidden_dim, 1)  # Edge classification

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.input_proj(x)

        for layer in self.transformer_layers:
            x = layer(x, edge_index).relu()

        node_logits = None
        if hasattr(data, 'masked_nodes'):
            masked_nodes = data.masked_nodes.bool()
            node_logits = self.node_predictor(x)

        # Edge logits for all possible pairs (optional: restrict to sampled pairs in loss)
        edge_logits_raw = torch.matmul(x, x.T)
        edge_logits = torch.sigmoid(edge_logits_raw)

        # Remove self-loop predictions by zeroing diagonal
        edge_logits.fill_diagonal_(0)

        return {
            'node_logits': node_logits,
            'edge_logits': edge_logits,
            'edge_logits_raw': edge_logits_raw,
            'final_node_embeddings': x
        }

def node_reconstruction_loss(output_logits, target_x, masked_nodes):
    loss_fn = nn.CrossEntropyLoss()
    target_classes = target_x.argmax(dim=1)
    masked_indices = masked_nodes.nonzero(as_tuple=True)[0]
    return loss_fn(output_logits[masked_indices], target_classes[masked_indices])


def edge_reconstruction_loss(edge_logits_raw, edge_index, num_nodes):
    # Ground truth adjacency matrix
    adj_true = torch.zeros((num_nodes, num_nodes), device=edge_logits_raw.device)
    adj_true[edge_index[0], edge_index[1]] = 1.0

    # Flatten for BCEWithLogitsLoss
    loss = F.binary_cross_entropy_with_logits(
        edge_logits_raw.view(-1),
        adj_true.view(-1)
    )

    return loss

def pad_to_full_graph(pred_adj, full_size=45):
    """
    Pads a square predicted adjacency matrix to a full_size x full_size matrix.
    """
    padded = torch.zeros(full_size, full_size, device=pred_adj.device)
    size = pred_adj.size(0)
    if size > full_size:
        raise ValueError(f"Predicted adjacency size {size} exceeds full size {full_size}")
    padded[:size, :size] = pred_adj
    return padded

def graph_edit_distance_loss(output, target, full_size=45):
    """
    Computes GED-like MSE loss between predicted and true adjacency matrices.
    """
    edge_logits_raw = output['edge_logits_raw']  # Expect shape [N, N] or [N*N]

    # Determine number of nodes in prediction
    if edge_logits_raw.dim() == 1:
        # Flattened, so reshape
        num_nodes = int(edge_logits_raw.size(0) ** 0.5)
        adj_pred = edge_logits_raw.view(num_nodes, num_nodes)
    elif edge_logits_raw.dim() == 2:
        # Already square
        adj_pred = edge_logits_raw
        num_nodes = adj_pred.size(0)
    else:
        raise ValueError("edge_logits_raw must be 1D or 2D")

    # Safety check
    if adj_pred.size(0) != adj_pred.size(1):
        raise ValueError("Predicted adjacency matrix must be square")

    # Pad predicted adjacency
    adj_pred_padded = pad_to_full_graph(adj_pred, full_size=full_size)

    # True adjacency matrix
    adj_true = torch.zeros(full_size, full_size, device=target.x.device)
    adj_true[target.edge_index[0], target.edge_index[1]] = 1.0

    # Match the shapes
    if adj_pred_padded.shape != adj_true.shape:
        raise ValueError(f"Shape mismatch: predicted {adj_pred_padded.shape}, true {adj_true.shape}")

    # MSE Loss
    adj_mse_loss = F.mse_loss(adj_pred_padded, adj_true)
    return adj_mse_loss


# Combined loss (for training)
def total_loss_fn(output, data):
    node_loss = node_reconstruction_loss(output['node_logits'], data.x, data.masked_nodes)
    edge_loss = edge_reconstruction_loss(output['edge_logits'], data.edge_index, data.x.size(0))

    # Optional GED
    ged_loss = graph_edit_distance_loss(output['edge_logits'], torch.zeros_like(output['edge_logits']))  # Replace with true adj if available

    total = node_loss + edge_loss + 0.1 * ged_loss
    return total


In [64]:
from sklearn.metrics import confusion_matrix, f1_score
def evaluate_edge_metrics(pred_adj, target_edge_index, threshold=0.5):
    num_nodes = pred_adj.size(0)
    true_adj = torch.zeros_like(pred_adj)
    true_adj[target_edge_index[0], target_edge_index[1]] = 1.0

    pred_binary = (pred_adj > threshold).float()

    y_true = true_adj.view(-1).cpu().numpy()
    y_pred = pred_binary.view(-1).cpu().numpy()

    acc = (y_true == y_pred).sum() / y_true.shape[0]
    f1 = f1_score(y_true, y_pred)  # binary

    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    return acc, f1, int(tp), int(fp), int(fn), int(tn)

In [72]:
def unbatch_output(output, data_list):
    """
    Splits batched output['edge_logits_raw'] into per-graph tensors.
    Assumes edge_logits_raw is 1D concatenated logits from all graphs.
    """
    edge_logits_raw = output['edge_logits_raw']
    output_list = []

    start = 0
    for data in data_list:
        n = data.num_nodes
        size = n * n
        raw = edge_logits_raw[start:start + size].view(n, n)
        out_dict = {
            'edge_logits_raw': raw,
            'edge_logits': torch.sigmoid(raw),
            'node_logits': output['node_logits'][data.batch == 0]  # If needed
        }
        output_list.append(out_dict)
        start += size

    return output_list


In [73]:
from tqdm import tqdm

def train(model, dataloader, optimizer, device):
    torch.autograd.set_detect_anomaly(True)
    model.train()

    total_loss = 0
    total_edge_f1 = 0
    batch_count = 0

    for batch in tqdm(dataloader):
        data = batch['input'].to(device)     # Batched input graphs
        target = batch['target'].to(device)  # Batched target graphs

        optimizer.zero_grad()
        output = model(data)

        # Node reconstruction loss (can be batched)
        node_loss = node_reconstruction_loss(
            output['node_logits'],
            target.x,
            data.masked_nodes
        )

        # Edge reconstruction loss (can be batched)
        edge_loss = edge_reconstruction_loss(
            output['edge_logits_raw'],
            target.edge_index,
            target.num_nodes
        )

        # GED loss (must be per graph in batch)
        data_list = data.to_data_list()
        target_list = target.to_data_list()
        output_list = unbatch_output(output, data_list)  # Helper splits batched outputs

        ged_loss = 0.0
        for out, tgt in zip(output_list, target_list):
            try:
                ged_loss += graph_edit_distance_loss(out, tgt, full_size=45)
            except Exception as e:
                print(f"[Warning] GED Loss skipped for one graph: {e}")
                ged_loss += 0.0  # skip but don’t crash

        ged_loss = ged_loss / len(output_list)

        # Total loss
        loss = node_loss + edge_loss + ged_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1

        # Edge metrics (optional: you can aggregate over whole batch too)
        edge_acc, edge_f1, *_ = evaluate_edge_metrics(output['edge_logits'], target.edge_index)
        total_edge_f1 += edge_f1

    avg_loss = total_loss / batch_count
    avg_edge_f1 = total_edge_f1 / batch_count

    print(f"Train Loss = {avg_loss:.4f} | Edge F1 avg batch = {avg_edge_f1:.2f}%")
    return avg_loss


In [74]:
@torch.no_grad()
def validate(model, dataloader, device):
    model.eval()
    total_loss = 0

    for batch in dataloader:
        data = batch['input'].to(device)
        target = batch['target'].to(device)

        output = model(data)

        node_loss = node_reconstruction_loss(
            output['node_logits'],
            target.x,
            data.masked_nodes
        )

        edge_loss = edge_reconstruction_loss(
            output['edge_logits'],
            target.edge_index,
            target.num_nodes
        )

        loss = node_loss + edge_loss
        total_loss += loss.item()

    return total_loss / len(dataloader)


In [76]:
device = torch.device("cpu")
model = GraphTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, 51):
    train_loss = train(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

100%|██████████| 3240/3240 [05:25<00:00,  9.94it/s]


Train Loss = 1.3147 | Edge F1 avg batch = 0.06%
Epoch 1: Train Loss = 1.3147, Val Loss = 1.0026


100%|██████████| 3240/3240 [05:30<00:00,  9.80it/s]


Train Loss = 0.7849 | Edge F1 avg batch = 0.07%
Epoch 2: Train Loss = 0.7849, Val Loss = 0.9712


100%|██████████| 3240/3240 [05:30<00:00,  9.80it/s]


Train Loss = 0.7423 | Edge F1 avg batch = 0.09%
Epoch 3: Train Loss = 0.7423, Val Loss = 0.9622


100%|██████████| 3240/3240 [05:31<00:00,  9.78it/s]


Train Loss = 0.7268 | Edge F1 avg batch = 0.15%
Epoch 4: Train Loss = 0.7268, Val Loss = 0.9588


100%|██████████| 3240/3240 [05:30<00:00,  9.80it/s]


Train Loss = 0.7204 | Edge F1 avg batch = 0.23%
Epoch 5: Train Loss = 0.7204, Val Loss = 0.9574


100%|██████████| 3240/3240 [05:31<00:00,  9.78it/s]


Train Loss = 0.7192 | Edge F1 avg batch = 0.33%
Epoch 6: Train Loss = 0.7192, Val Loss = 0.9570


100%|██████████| 3240/3240 [05:32<00:00,  9.76it/s]


Train Loss = 0.7171 | Edge F1 avg batch = 0.36%
Epoch 7: Train Loss = 0.7171, Val Loss = 0.9567


100%|██████████| 3240/3240 [05:30<00:00,  9.79it/s]


Train Loss = 0.7164 | Edge F1 avg batch = 0.36%
Epoch 8: Train Loss = 0.7164, Val Loss = 0.9566


100%|██████████| 3240/3240 [05:33<00:00,  9.73it/s]


Train Loss = 0.7176 | Edge F1 avg batch = 0.39%
Epoch 9: Train Loss = 0.7176, Val Loss = 0.9565


 19%|█▉        | 617/3240 [01:02<04:23,  9.95it/s]


KeyboardInterrupt: 