In [None]:
from __future__ import annotations
import os
import math
import json
import argparse
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)

In [None]:
# ---------------------------
# Utils
# ---------------------------

def get_device(prefer_gpu_index: int = 0) -> torch.device:
    """Return ``cuda:<prefer_gpu_index>`` when that GPU exists, otherwise the CPU device.

    Args:
        prefer_gpu_index: CUDA device ordinal to use. A negative value (e.g. ``--gpu -1``)
            explicitly requests the CPU.

    Returns:
        The selected ``torch.device``.
    """
    # Without this guard a negative ordinal passes the device_count() check on GPU hosts
    # and produces an invalid "cuda:-1" device.
    if prefer_gpu_index < 0:
        return torch.device("cpu")
    if torch.cuda.is_available() and torch.cuda.device_count() > prefer_gpu_index:
        return torch.device(f"cuda:{prefer_gpu_index}")
    return torch.device("cpu")

In [None]:
# ---------------------------
# Dataset
# ---------------------------

class EDKGDataset(InMemoryDataset):
    """
    In-memory PyG dataset of compound graphs read from per-compound text files.

    Expected layout:
      data_root/
        Graph_label.txt        # two columns: compound_id, label
        0/
          Graph_index.txt
          Graph_edge_index_direct.txt
        1/
          ...

    Only sub-directories whose name is all digits are read; that name is the compound id
    looked up in ``Graph_label.txt``. Compounds missing either graph file are skipped.
    The collated graphs are saved under ``cache_root`` (PyG skips ``process()`` when the
    processed file already exists there).

    NOTE(review): the cache is not keyed on ``data_root``; clear ``cache_root`` after the
    raw data changes, otherwise the old graphs are reloaded.
    """
    def __init__(self, data_root: str, cache_root: str, transform=None, pre_transform=None,
                 pre_filter=None):
        # data_root must exist before super().__init__, which may call process().
        self.data_root = Path(data_root)
        self._processed_dir = Path(cache_root)
        self._processed_dir.mkdir(parents=True, exist_ok=True)
        super().__init__(root=str(self._processed_dir), transform=transform,
                         pre_transform=pre_transform, pre_filter=pre_filter)
        # weights_only=False is required because the cache holds pickled PyG objects written
        # by process(); never point cache_root at files from an untrusted source.
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self) -> List[str]:
        # Raw files live in data_root (outside PyG's raw_dir), so nothing is downloaded.
        return []

    @property
    def processed_file_names(self) -> List[str]:
        return ["data.pt"]

    def process(self) -> None:
        """Build one ``Data`` per complete compound directory and save the collated result.

        Node features are every column of ``Graph_index.txt``. In
        ``Graph_edge_index_direct.txt`` columns 0-1 are (source, target) node indices and
        any remaining columns are edge features. ``pre_filter``/``pre_transform`` (if
        given) are applied to each graph before collation.

        Raises:
            FileNotFoundError: ``Graph_label.txt`` is missing.
            IndexError: a compound id is not in the label table and the positional
                fallback row does not exist.
            RuntimeError: no graph survived loading/filtering.
        """
        import pandas as pd

        label_path = self.data_root / "Graph_label.txt"
        if not label_path.exists():
            raise FileNotFoundError(f"Missing label table: {label_path}")

        node_file = "Graph_index.txt"
        edge_file = "Graph_edge_index_direct.txt"

        labels_df = pd.read_csv(label_path, header=None)  # [compound_id, label]
        # Build the id -> label map once instead of filtering the whole table per compound
        # (O(n) instead of O(n^2)). setdefault keeps the first row for a repeated id.
        label_by_id = {}
        for cid, lab in zip(labels_df.iloc[:, 0].tolist(), labels_df.iloc[:, 1].tolist()):
            label_by_id.setdefault(cid, lab)

        compounds = [p for p in self.data_root.iterdir() if p.is_dir() and p.name.isdigit()]
        compounds = sorted(compounds, key=lambda p: int(p.name))

        data_list: List[Data] = []
        for comp_dir in compounds:
            idx = int(comp_dir.name)
            np_path = comp_dir / node_file
            ep_path = comp_dir / edge_file
            if not np_path.exists() or not ep_path.exists():
                # quietly skip incomplete compound
                continue

            # node features: all columns of the node file
            ndf = pd.read_csv(np_path, header=None)
            x = torch.tensor(ndf.values, dtype=torch.float)

            # edges (columns 0-1) and edge features (remaining columns)
            edf = pd.read_csv(ep_path, header=None)
            edge_index = torch.tensor(edf.iloc[:, 0:2].T.values, dtype=torch.long)
            edge_attr = torch.tensor(edf.iloc[:, 2:].values, dtype=torch.float)

            # label by compound id, falling back to the row at the same position
            if idx in label_by_id:
                label = label_by_id[idx]
            elif idx < len(labels_df):
                label = labels_df.iloc[idx, 1]
            else:
                raise IndexError(
                    f"No label for compound {idx}: id not found in {label_path} and the "
                    f"positional fallback is out of range ({len(labels_df)} rows)"
                )
            y = torch.tensor(int(label), dtype=torch.long).view(1)

            graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
            # Apply the standard PyG hooks given to __init__.
            if self.pre_filter is not None and not self.pre_filter(graph):
                continue
            if self.pre_transform is not None:
                graph = self.pre_transform(graph)
            data_list.append(graph)

        if not data_list:
            raise RuntimeError(f"No complete compound graphs found under {self.data_root}")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [None]:
# ---------------------------
# Model
# ---------------------------

class GCNConvEdge(MessagePassing):
    """GCN-like layer with linear edge feature mapping and concatenated message.

    Each message along edge (j -> i) is ``norm_ij * [W_n x_j || W_e e_ij]``, where
    ``norm_ij = deg(j)^-1/2 * deg(i)^-1/2`` is computed after adding one self-loop per
    node. Self-loop messages use a zero edge embedding. Messages are summed, so the node
    output has ``2 * out_channels`` columns (plus a learned bias).

    ``forward`` returns ``(node_out, edge_out)``. ``edge_out = W_e edge_attr`` has one row
    per *input* edge (self-loops excluded), so it can be passed as ``edge_attr`` to the
    next layer together with the same ``edge_index``.
    """
    def __init__(self, in_channels: int, out_channels: int, edge_channels: int):
        super().__init__(aggr="add")
        self.lin_node = nn.Linear(in_channels, out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_channels, out_channels, bias=False)
        # bias spans both halves of the concatenated message
        self.bias = nn.Parameter(torch.zeros(2 * out_channels))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.lin_node.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lin_edge.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        num_nodes = x.size(0)

        # add_self_loops appends one loop per node after the original edges.
        edge_index_sl, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # map node and edge features
        x = self.lin_node(x)
        edge_emb = self.lin_edge(edge_attr)  # one row per original edge

        # Self-loop messages carry a zero edge embedding; the rows go last to line up with
        # the appended loops.
        loop_emb = edge_emb.new_zeros((num_nodes, edge_emb.size(1)))
        msg_edge_emb = torch.cat([edge_emb, loop_emb], dim=0)

        # normalized aggregation coefficients (every degree >= 1 thanks to self-loops)
        row, col = edge_index_sl
        deg = degree(col, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = self.propagate(edge_index_sl, x=x, norm=norm, ex=msg_edge_emb)
        out = out + self.bias
        # Return the unpadded per-edge embedding. The padded tensor has E+N rows; feeding it
        # to the next layer (which pads again) would mismatch its E+N messages and crash.
        return out, edge_emb

    def message(self, x_j: torch.Tensor, norm: torch.Tensor, ex: torch.Tensor) -> torch.Tensor:
        return norm.view(-1, 1) * torch.cat([x_j, ex], dim=1)

class EDKGGCN(nn.Module):
    """Graph classifier: three GCNConvEdge layers -> mean pooling -> dropout -> linear head.

    Each GCNConvEdge emits ``2 * hidden`` node features, which is why every following
    layer (and the head) takes a doubled input width. The edge embedding produced by one
    layer is the ``edge_attr`` of the next.
    """
    def __init__(self, in_node: int, in_edge: int, h1: int, h2: int, h3: int, num_classes: int):
        super().__init__()
        self.conv1 = GCNConvEdge(in_channels=in_node, out_channels=h1, edge_channels=in_edge)
        self.conv2 = GCNConvEdge(in_channels=2 * h1, out_channels=h2, edge_channels=h1)
        self.conv3 = GCNConvEdge(in_channels=2 * h2, out_channels=h3, edge_channels=h2)
        self.fc = nn.Linear(in_features=2 * h3, out_features=num_classes)

    def forward(self, data: Data) -> torch.Tensor:
        """Return per-graph logits for a (batched) PyG ``Data`` object."""
        edge_index = data.edge_index
        hidden, edge_feat = data.x, data.edge_attr

        # first two layers: ReLU on node features, edge embedding carried forward
        for layer in (self.conv1, self.conv2):
            hidden, edge_feat = layer(hidden, edge_index, edge_feat)
            hidden = F.relu(hidden)

        # last layer: no activation, its edge output is unused
        hidden = self.conv3(hidden, edge_index, edge_feat)[0]

        graph_repr = global_mean_pool(hidden, data.batch)
        graph_repr = F.dropout(graph_repr, p=0.5, training=self.training)
        return self.fc(graph_repr)

In [None]:
# ---------------------------
# Evaluation
# ---------------------------

@torch.no_grad()
def evaluate_binary(model: nn.Module, loader: DataLoader, device: torch.device,
                    threshold: float = 0.5, average: str = "binary") -> dict:
    """Run ``model`` over ``loader`` and compute binary-classification metrics.

    The model may emit a single logit per graph (sigmoid is applied) or two logits
    (softmax is applied; column 1 is the positive class). In both cases the positive-class
    probability ``p1`` is thresholded at ``threshold`` to obtain hard predictions.

    Args:
        model: network mapping a batch to logits of shape (B,), (B, 1) or (B, 2).
        loader: iterable of batches exposing ``.to(device)`` and ``.y``.
        device: device the batches are moved to.
        threshold: decision threshold applied to ``p1``.
        average: ``average`` argument forwarded to sklearn precision/recall/F1.

    Returns:
        Dict with TP/TN/FP/FN counts, accuracy, precision, recall, f1, and roc_auc /
        pr_auc (``None`` when sklearn reports them as undefined, e.g. only one class
        present in the labels).

    Raises:
        ValueError: the loader yields no batches, or the model emits more than 2 logits.
    """
    model.eval()
    y_true, y_prob, y_pred = [], [], []

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)

        if logits.ndim == 1 or logits.shape[1] == 1:
            p1 = torch.sigmoid(logits.view(-1))
        elif logits.shape[1] == 2:
            p1 = torch.softmax(logits, dim=1)[:, 1]
        else:
            raise ValueError(
                f"evaluate_binary expects 1 or 2 logits per sample, got {logits.shape[1]}"
            )
        # Apply the same decision rule to both output layouts (previously the two-logit
        # path used argmax and ignored `threshold`).
        pred = (p1 >= threshold).long()

        labels = batch.y.view(-1).long()
        y_true.append(labels.detach().cpu().numpy())
        y_prob.append(p1.detach().cpu().numpy())
        y_pred.append(pred.detach().cpu().numpy())

    if not y_true:
        raise ValueError("evaluate_binary: loader produced no batches")

    y_true = np.concatenate(y_true, axis=0)
    y_prob = np.concatenate(y_prob, axis=0)
    y_pred = np.concatenate(y_pred, axis=0)

    acc = accuracy_score(y_true, y_pred)
    pre = precision_score(y_true, y_pred, zero_division=0, average=average)
    rec = recall_score(y_true, y_pred, zero_division=0, average=average)
    f1  = f1_score(y_true, y_pred, zero_division=0, average=average)

    # Fixed labels keep the matrix 2x2 even if one class is absent from the data.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = cm.ravel()

    # sklearn raises ValueError when a ranking metric is undefined (e.g. single class);
    # report that as None rather than masking unrelated errors.
    try:
        roc_auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        roc_auc = None
    try:
        pr_auc = average_precision_score(y_true, y_prob)
    except ValueError:
        pr_auc = None

    return {
        "TP": int(TP), "TN": int(TN), "FP": int(FP), "FN": int(FN),
        "accuracy": float(acc), "precision": float(pre),
        "recall": float(rec), "f1": float(f1),
        "roc_auc": None if roc_auc is None else float(roc_auc),
        "pr_auc":  None if pr_auc  is None else float(pr_auc),
    }


In [None]:
# ---------------------------
# Main
# ---------------------------

def main(argv: Optional[List[str]] = None):
    """Evaluate a saved EDKGGCN state_dict on the full dataset and report the metrics.

    Args:
        argv: command-line arguments to parse. ``None`` lets argparse read
            ``sys.argv[1:]``. Pass an explicit list when calling from a notebook, e.g.
            ``main(["--ckpt", "model_state.pt"])``.

    Raises:
        FileNotFoundError: the checkpoint path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="edkgdl_all_data")
    parser.add_argument("--cache_root", type=str, default="EDKG-DL_cache")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model_state.pt")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--h1", type=int, default=50)
    parser.add_argument("--h2", type=int, default=20)
    parser.add_argument("--h3", type=int, default=60)
    parser.add_argument("--gpu", type=int, default=0, help="preferred GPU index")
    parser.add_argument("--report_path", type=str, default=None, help="Optional path to save JSON report")
    args = parser.parse_args(argv)

    device = get_device(args.gpu)

    # Load dataset from cache or build it
    dataset = EDKGDataset(data_root=args.data_root, cache_root=args.cache_root)
    print(f"Loaded dataset: graphs={len(dataset)}, node_feats={dataset.num_node_features}, edge_feats={dataset.num_edge_features}, classes={dataset.num_classes}")

    # Build model using dataset dims (hidden sizes must match the checkpoint)
    model = EDKGGCN(
        in_node=dataset.num_node_features,
        in_edge=dataset.num_edge_features,
        h1=args.h1, h2=args.h2, h3=args.h3,
        num_classes=dataset.num_classes
    ).to(device)

    # Load checkpoint on the chosen device gracefully
    ckpt_path = Path(args.ckpt)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # A state_dict only needs tensors; weights_only=True refuses arbitrary pickled objects
    # from the user-supplied file.
    state = torch.load(str(ckpt_path), map_location=device, weights_only=True)
    model.load_state_dict(state, strict=True)

    # Full-dataset evaluation by default
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    metrics = evaluate_binary(model, loader, device=device)

    def _fmt(value) -> str:
        # AUC metrics may be None when undefined; format like the other scores otherwise.
        return "N/A" if value is None else f"{value:.4f}"

    print("\n=== Evaluation Results ===")
    print(f"TP: {metrics['TP']}")
    print(f"TN: {metrics['TN']}")
    print(f"FP: {metrics['FP']}")
    print(f"FN: {metrics['FN']}")
    print(f"Accuracy:  {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall:    {metrics['recall']:.4f}")
    print(f"F1-score:  {metrics['f1']:.4f}")
    print(f"ROC-AUC:   {_fmt(metrics['roc_auc'])}")
    print(f"PR-AUC:    {_fmt(metrics['pr_auc'])}")

    if args.report_path:
        report = {
            "data_root": args.data_root,
            "cache_root": args.cache_root,
            "ckpt": str(ckpt_path),
            "batch_size": args.batch_size,
            "hidden_sizes": [args.h1, args.h2, args.h3],
            "metrics": metrics
        }
        report_path = Path(args.report_path)
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        print(f"\nSaved JSON report to: {args.report_path}")

if __name__ == "__main__":
    import sys
    # Notebook cells also run with __name__ == "__main__", but the kernel's sys.argv holds
    # its own flags (e.g. "-f <connection file>") and --ckpt is required, so argparse would
    # abort the cell with SystemExit. In a notebook, call main([...]) explicitly instead.
    if "ipykernel" in sys.modules:
        print('Jupyter kernel detected: call main(["--ckpt", "<path/to/model_state.pt>"]) to run the evaluation.')
    else:
        main()