In [None]:
from __future__ import annotations
import os
import math
import json
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Tuple, List, Optional

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Subset

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree


In [None]:
# ---------------------------
# 2) Config
# ---------------------------

@dataclass
class TrainConfig:
    """Hyper-parameters and I/O locations for one training run.

    Attributes:
        root_raw: Raw folder containing the per-compound subfolders.
        cache_root: Where the processed ``data.pt`` is written.
        batch_size: Graphs per mini-batch.
        test_ratio: Split ratio used for the held-out set.
        max_epochs: Upper bound on the number of training epochs.
        lr: Optimizer learning rate.
        weight_decay: Optimizer weight decay.
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
        hidden_size3: Width of the third hidden layer.
        seed: Random seed.
        wandb_project: W&B project name.
        wandb_mode: W&B mode; offline is the default for portability.
        wandb_dir: Directory that receives the W&B run folders.
        patience: Early-stop patience (optional, 0 disables it).
        num_workers: Number of DataLoader workers.
    """

    # --- data locations ---
    root_raw: str = "edkgdl_all_data"
    cache_root: str = "EDKG-DL_cache"

    # --- optimisation ---
    batch_size: int = 64
    test_ratio: float = 0.15
    max_epochs: int = 200
    lr: float = 5e-4
    weight_decay: float = 5e-4

    # --- model widths ---
    hidden_size1: int = 60
    hidden_size2: int = 40
    hidden_size3: int = 30

    # --- reproducibility / logging ---
    seed: int = 12345
    wandb_project: str = "EDCs-BT-GNN-classification"
    wandb_mode: str = "offline"
    wandb_dir: str = "wandb_runs"

    # --- training loop control ---
    patience: int = 20
    num_workers: int = 0

In [None]:
# ---------------------------
# 3) Dataset
# ---------------------------

class EDKGDataset(InMemoryDataset):
    """
    In-memory graph-classification dataset built from per-compound text files.

    Expects a folder structure:
      root_raw/
        Graph_label.txt  (global labels, 2 columns: [compound_id, label])
        0/
          Graph_index.txt
          Graph_edge_index_direct.txt
        1/
          ...
        ...

    Every numerically named subfolder that contains both files becomes one
    ``Data`` object (node features, edge index, edge attributes, graph label).
    The collated result is saved to ``self.processed_paths[0]`` and loaded from
    there by the constructor.

    Args:
        root_raw: Folder holding ``Graph_label.txt`` and the compound subfolders.
        cache_root: Root directory handed to ``InMemoryDataset`` for caching.
        transform: Optional per-access transform (forwarded to the base class).
        pre_transform: Optional transform applied once to every graph before it
            is written to the cache.
    """
    def __init__(self, root_raw: str, cache_root: str, transform=None, pre_transform=None):
        self.root_raw = Path(root_raw)
        self._processed_dir = Path(cache_root)
        self._processed_dir.mkdir(parents=True, exist_ok=True)
        super().__init__(str(self._processed_dir), transform, pre_transform)
        # torch.load(..., weights_only=False) for compatibility with PyTorch>=2.6 default behavior change
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self) -> List[str]:
        # Managed manually in process()
        return []

    @property
    def processed_file_names(self) -> List[str]:
        return ["data.pt"]

    def process(self) -> None:
        """Read all compound folders, build the graphs and write the cache file.

        Raises:
            FileNotFoundError: If the label table is missing, or if no compound
                folder contains both the node and the edge file.
            IndexError: If a compound has no matching id in the label table and
                its index is also outside the table for the positional fallback.
        """
        import pandas as pd

        # file names inside each compound folder
        node_file = "Graph_index.txt"
        edge_file = "Graph_edge_index_direct.txt"
        label_table = self.root_raw / "Graph_label.txt"

        if not label_table.exists():
            raise FileNotFoundError(f"Missing label table: {label_table}")

        # Expect [compound_id, label] columns
        labels_df = pd.read_csv(label_table, header=None)

        # Only numerically named subfolders are compounds; process them in numeric order.
        compound_dirs = sorted(
            (p for p in self.root_raw.iterdir() if p.is_dir() and p.name.isdigit()),
            key=lambda p: int(p.name),
        )

        data_list: List[Data] = []
        for comp_dir in compound_dirs:
            idx = int(comp_dir.name)
            node_path = comp_dir / node_file
            edge_path = comp_dir / edge_file

            if not node_path.exists() or not edge_path.exists():
                # silently skip incomplete entries
                continue

            # nodes: one row per node, all columns are node features
            node_df = pd.read_csv(node_path, header=None)
            x = torch.tensor(node_df.values, dtype=torch.float)

            # edges: columns are src, dst, edge_feats...
            edge_df = pd.read_csv(edge_path, header=None)
            edge_index = torch.tensor(edge_df.iloc[:, 0:2].T.values, dtype=torch.long)
            edge_attr = torch.tensor(edge_df.iloc[:, 2:].values, dtype=torch.float)

            # label: prefer the row whose first column equals the compound id,
            # otherwise fall back to the row at position idx
            row = labels_df[labels_df.iloc[:, 0] == idx]
            if row.empty:
                if idx >= len(labels_df):
                    raise IndexError(
                        f"No label for compound {idx} in {label_table}: no row with that id "
                        f"and the table has only {len(labels_df)} rows for positional lookup"
                    )
                row = labels_df.iloc[[idx]]
            label_value = int(row.iloc[0, 1])
            label = torch.tensor(label_value, dtype=torch.long).view(1)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=label)
            # Honor the pre_transform accepted by the constructor.
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        if not data_list:
            # collate() on an empty list fails with an unhelpful error; be explicit instead.
            raise FileNotFoundError(
                f"No compound folder under {self.root_raw} contains both {node_file} and {edge_file}"
            )

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [None]:
# ---------------------------
# 4) Model
# ---------------------------

class GCNConvEdge(MessagePassing):
    """GCN-style message-passing layer whose messages also carry edge features.

    Node and edge inputs are each projected to ``out_channels`` without bias;
    every message is the symmetric-degree-normalized concatenation of the
    source-node projection and the edge projection, summed per target node.
    The node output therefore has ``2 * out_channels`` columns.
    """
    def __init__(self, in_channels: int, out_channels: int, edge_channels: int):
        super().__init__(aggr="add")
        self.lin_node = nn.Linear(in_channels, out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_channels, out_channels, bias=False)
        # one bias entry per column of the concatenated [node, edge] message
        self.bias = nn.Parameter(torch.zeros(2 * out_channels))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize both projections (node first, then edge) and zero the bias."""
        for projection in (self.lin_node, self.lin_edge):
            nn.init.kaiming_uniform_(projection.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(node_output, projected_edge_features)``.

        The second element holds the projected features of the original edges
        only (self-loop padding excluded).
        """
        num_nodes = x.size(0)
        looped_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        node_feats = self.lin_node(x)                 # [N, C]
        edge_feats = self.lin_edge(edge_attr)         # [E, C]

        # the self-loops are appended after the original edges, so pad with
        # one all-zero feature row per node at the end
        loop_feats = torch.zeros(
            (num_nodes, edge_feats.size(1)),
            device=edge_feats.device,
            dtype=edge_feats.dtype,
        )
        padded_edge_feats = torch.cat((edge_feats, loop_feats), dim=0)   # [E + N, C]

        src, dst = looped_index
        in_degree = degree(dst, num_nodes, dtype=node_feats.dtype)
        inv_sqrt_deg = in_degree.pow(-0.5)
        inv_sqrt_deg = inv_sqrt_deg.masked_fill(inv_sqrt_deg == float("inf"), 0)
        edge_norm = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

        aggregated = self.propagate(looped_index, x=node_feats, norm=edge_norm, ex=padded_edge_feats)
        return aggregated + self.bias, edge_feats

    def message(self, x_j: torch.Tensor, norm: torch.Tensor, ex: torch.Tensor) -> torch.Tensor:
        # per-edge message: [source node features | edge features], scaled by the edge norm
        stacked = torch.cat([x_j, ex], dim=1)
        return norm.view(-1, 1) * stacked

class EDKGGCN(nn.Module):
    """Three stacked ``GCNConvEdge`` layers, mean pooling and a linear classifier.

    Each convolution emits node features of width ``2 * hidden`` and edge
    features of width ``hidden``, which is why the next layer takes
    ``2 * hidden`` node channels and ``hidden`` edge channels.
    """
    def __init__(self, in_node: int, in_edge: int, hidden1: int, hidden2: int, hidden3: int, num_classes: int):
        super().__init__()
        self.conv1 = GCNConvEdge(in_node, hidden1, in_edge)
        self.conv2 = GCNConvEdge(2 * hidden1, hidden2, hidden1)
        self.conv3 = GCNConvEdge(2 * hidden2, hidden3, hidden2)
        self.out = nn.Linear(2 * hidden3, num_classes)

    def forward(self, data: Data) -> torch.Tensor:
        """Return one row of class logits per graph in the batch."""
        edge_index = data.edge_index
        graph_ids = data.batch

        node_h, edge_h = self.conv1(data.x, edge_index, data.edge_attr)
        node_h, edge_h = self.conv2(F.relu(node_h), edge_index, edge_h)
        # no activation after the last convolution; its edge output is unused
        node_h, _ = self.conv3(F.relu(node_h), edge_index, edge_h)

        graph_embedding = global_mean_pool(node_h, graph_ids)
        graph_embedding = F.dropout(graph_embedding, p=0.5, training=self.training)
        return self.out(graph_embedding)

In [None]:
# ---------------------------
# 5) Metrics
# ---------------------------

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> dict:
    """Run ``model`` over ``loader`` and return binary confusion-matrix metrics.

    Class 1 is treated as the positive class. The returned dict has the keys
    TP, TN, FP, FN, Accuracy, Precision, Recall and F1; a ratio whose
    denominator is zero is reported as 0.0.
    """
    model.eval()

    logit_chunks: List[torch.Tensor] = []
    label_chunks: List[torch.Tensor] = []
    for graphs in loader:
        graphs = graphs.to(device)
        logit_chunks.append(model(graphs).detach().cpu())
        label_chunks.append(graphs.y.view(-1).detach().cpu())

    predicted = torch.cat(logit_chunks, dim=0).argmax(dim=1)
    actual = torch.cat(label_chunks, dim=0).long()

    def _count(pred_class: int, true_class: int) -> int:
        # number of samples predicted as pred_class whose label is true_class
        return ((predicted == pred_class) & (actual == true_class)).sum().item()

    tp = _count(1, 1)
    tn = _count(0, 0)
    fp = _count(1, 0)
    fn = _count(0, 1)

    n_counted = tp + tn + fp + fn
    n_pred_pos = tp + fp
    n_true_pos = tp + fn

    acc = (tp + tn) / n_counted if n_counted > 0 else 0.0
    pre = tp / n_pred_pos if n_pred_pos > 0 else 0.0
    rec = tp / n_true_pos if n_true_pos > 0 else 0.0
    if (pre + rec) > 0:
        f1 = 2 * pre * rec / (pre + rec)
    else:
        f1 = 0.0

    return {
        "TP": tp, "TN": tn, "FP": fp, "FN": fn,
        "Accuracy": acc, "Precision": pre, "Recall": rec, "F1": f1
    }

In [None]:
# ---------------------------
# 6) Split helpers
# ---------------------------

def stratified_indices_by_label(dataset: InMemoryDataset, test_ratio: float, seed: int) -> Tuple[List[int], List[int]]:
    """Split graph indices into train/test while keeping the label balance.

    Graphs labelled 1 form one group and all other graphs the second group.
    Each group is shuffled with a ``random.Random(seed)`` generator and its
    first ``int(len(group) * test_ratio)`` members (at least one when the
    group is non-empty) go to the test split.

    Returns:
        ``(train_idx, test_idx)``, both in ascending index order.
    """
    groups = {False: [], True: []}
    for position in range(len(dataset)):
        label = int(dataset[position].y.item())
        groups[label == 1].append(position)
    others, positives = groups[False], groups[True]

    # one generator shuffles both groups: the non-positive group first, then the positives
    shuffler = random.Random(seed)
    shuffler.shuffle(others)
    shuffler.shuffle(positives)

    def _holdout_size(group: List[int]) -> int:
        if len(group) == 0:
            return 0
        return max(1, int(len(group) * test_ratio))

    held_out = set(others[:_holdout_size(others)]) | set(positives[:_holdout_size(positives)])
    kept = [position for position in range(len(dataset)) if position not in held_out]
    return kept, sorted(held_out)

In [None]:
# ---------------------------
# 7) Main training
# ---------------------------

def main(cfg: TrainConfig) -> None:
    """Train the EDKG GCN classifier, log to W&B and save the best checkpoint.

    Steps: configure W&B through environment variables, load the cached
    dataset, make a stratified train/test split, train with Adam and
    cross-entropy, evaluate both splits after every epoch, keep the weights
    with the best test F1 (with optional early stopping), then write
    ``model_state.pt`` and ``run_summary.json`` into the W&B run directory.

    The final metrics are computed with the best weights restored, so they
    describe the checkpoint that is saved.

    Args:
        cfg: Training configuration.

    Raises:
        ValueError: If the split leaves the train or the test set empty.
    """
    # NOTE(review): set_seed and get_device are not defined in the cells visible
    # here; confirm an earlier cell or module provides them before Run All.
    set_seed(cfg.seed)
    device = get_device()

    # W&B setup (offline by default)
    import wandb
    Path(cfg.wandb_dir).mkdir(parents=True, exist_ok=True)
    os.environ.update({
        "WANDB_MODE": cfg.wandb_mode,
        "WANDB_DIR": str(Path(cfg.wandb_dir).resolve()),
        "WANDB_DISABLE_SERVICE": "true",
        "WANDB_DISABLE_CODE": "true",
        "WANDB_DISABLE_GIT": "true",
        "WANDB_DISABLE_NOTEBOOK": "true",
        "WANDB_DISABLE_GPU": "true",
        "WANDB_HTTP_TIMEOUT": "30",
        "WANDB_INIT_TIMEOUT": "30",
    })

    # Load dataset (cached)
    dataset = EDKGDataset(root_raw=cfg.root_raw, cache_root=cfg.cache_root)
    print(f"Dataset loaded. Graphs={len(dataset)}, Num node feats={dataset.num_node_features}, Num edge feats={dataset.num_edge_features}")

    # Split
    train_idx, test_idx = stratified_indices_by_label(dataset, test_ratio=cfg.test_ratio, seed=cfg.seed)
    if not train_idx or not test_idx:
        # An empty loader would otherwise fail later with an obscure error.
        raise ValueError(
            f"Split produced {len(train_idx)} training and {len(test_idx)} test graphs; both must be non-empty"
        )
    train_set = Subset(dataset, train_idx)
    test_set  = Subset(dataset, test_idx)
    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
    test_loader  = DataLoader(test_set,  batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

    # Model
    model = EDKGGCN(
        in_node=dataset.num_node_features,
        in_edge=dataset.num_edge_features,
        hidden1=cfg.hidden_size1,
        hidden2=cfg.hidden_size2,
        hidden3=cfg.hidden_size3,
        num_classes=dataset.num_classes
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    criterion = nn.CrossEntropyLoss()

    wandb.init(project=cfg.wandb_project, config=asdict(cfg))
    try:
        wandb.watch(model, log="all", log_freq=50)

        # optional early stopping
        # NOTE(review): the best epoch is selected on the test split, so the
        # reported test metrics are optimistic; a separate validation split
        # would be needed for an unbiased estimate.
        best_f1 = -1.0
        best_state = None
        epochs_no_improve = 0

        for epoch in range(1, cfg.max_epochs + 1):
            model.train()
            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()
                logits = model(batch)
                loss = criterion(logits, batch.y.view(-1))
                loss.backward()
                optimizer.step()

            train_metrics = evaluate(model, train_loader, device)
            test_metrics  = evaluate(model, test_loader, device)

            # logging
            wandb.log({
                "loss/train_last": float(loss.detach().cpu()),
                **{f"train/{k}": v for k, v in train_metrics.items()},
                **{f"test/{k}": v for k, v in test_metrics.items()},
                "epoch": epoch
            })

            # track best by F1 on test
            if test_metrics["F1"] > best_f1:
                best_f1 = test_metrics["F1"]
                # clone(): on CPU, .cpu() returns the live tensor itself, so
                # without a copy later optimizer steps would overwrite the snapshot
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if cfg.patience and epochs_no_improve >= cfg.patience:
                print(f"Early stop at epoch {epoch} with best F1={best_f1:.4f}")
                break

        # Save artifacts in the current W&B run directory
        run_dir = Path(wandb.run.dir)
        ckpt_path = run_dir / "model_state.pt"
        meta_path = run_dir / "run_summary.json"

        if best_state is not None:
            # Restore the best weights so the final metrics match the saved checkpoint.
            model.load_state_dict(best_state)
            torch.save(best_state, ckpt_path)
        else:
            torch.save(model.state_dict(), ckpt_path)

        final_train = evaluate(model, train_loader, device)
        final_test  = evaluate(model, test_loader, device)

        # W&B summary
        for k, v in {**{f"final_train/{k}": v for k, v in final_train.items()},
                     **{f"final_test/{k}": v for k, v in final_test.items()}}.items():
            wandb.summary[k] = v

        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump({
                "config": asdict(cfg),
                "final_train": final_train,
                "final_test": final_test
            }, f, indent=2)

        print(f"\nSaved checkpoint to: {ckpt_path}")
        print(f"Final Test — Acc: {final_test['Accuracy']:.4f}  F1: {final_test['F1']:.4f}  "
              f"Prec: {final_test['Precision']:.4f}  Rec: {final_test['Recall']:.4f}")
    finally:
        # Close the run even when training fails, so offline runs are not left open.
        wandb.finish()

# Entry point: train with the default configuration. `cfg` is bound at module
# level, so it stays available to later notebook cells after this one runs.
if __name__ == "__main__":
    cfg = TrainConfig()
    main(cfg)

In [None]:
import os, sys, subprocess, glob

def sync_path(p):
    """Upload one offline W&B run directory by invoking ``wandb sync``.

    A path ending in ``<sep>files`` is replaced by its parent directory. If the
    resulting path does not exist nothing is run. The command's exit status is
    not checked.
    """
    print("Syncing:", p)

    # run the child process in online mode without touching this process' environment
    online_env = dict(os.environ, WANDB_MODE="online")

    target = os.path.dirname(p) if p.endswith(os.sep + "files") else p
    if not os.path.exists(target):
        return

    command = [sys.executable, "-m", "wandb", "sync", target]
    subprocess.run(command, check=False, env=online_env)

# Locate the run folders without a machine-specific absolute path: W&B writes
# runs into a "wandb" subfolder of WANDB_DIR (set by main() from cfg.wandb_dir);
# fall back to the default "wandb_runs" folder relative to the working directory.
wandb_runs_root = os.path.join(os.environ.get("WANDB_DIR", "wandb_runs"), "wandb")

# Offline runs first, then sweeps; sorted so the sync order is deterministic.
for run_pattern in ("offline-run-*", "sweep-*"):
    for p in sorted(glob.glob(os.path.join(wandb_runs_root, run_pattern))):
        sync_path(p)

print("All done. Check your project on W&B.")

Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_215727-t4ufj57v
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_220201-z57gdbrv
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_220558-bq48ollz
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_220935-qvozpzxf
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_221415-whzmyhto
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_221758-xzeaeacm
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_222153-azyt2ldt
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_222605-t7m1ebtm
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_223109-rru00m7x
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_223651-sieu5a72
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_224020-ok7e2zu2
Syncing: E:\THY\EDC-AOP\GNN\wandb_runs\wandb\offline-run-20250811_224440-q1991kpf
Syncing: E:\THY\