# Import

## Official Import

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.nn import GCNConv, global_mean_pool
import pymetis
import numpy as np
from torch_geometric.utils import to_networkx, k_hop_subgraph, subgraph
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import networkx as nx
import yaml
from tqdm import tqdm
import time
import random
import optuna
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize


# Config

## Device initialization

In [2]:
from utils.graph_processing import get_device

# Wall-clock reference taken at the start of the run.
start_time = time.time()

# Seed the Python, NumPy and PyTorch random number generators with the same
# fixed value so repeated runs draw the same random numbers.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Device object returned by the project helper `get_device()`.
device = get_device()

Using device: cuda


## Config Setting

In [None]:
from utils.graph_processing import load_config

# Load the experiment configuration and expose each setting as a module-level
# constant used by the cells below.
config = load_config(config_path="config.yaml")

# Dataset name and storage root.
DA_NA = config["dataset"]["name"]
DA_RO = config["dataset"]["root"]

# Optimizer hyper-parameters. Both are coerced with float() so that a value the
# config loader hands back as a string is still usable by the optimizer
# (presumably values such as `5e-4` are loaded as strings -- TODO confirm
# against load_config). float() leaves already-numeric values unchanged.
OPLR = float(config["optimizer"]["lr"])
OPWE = float(config["optimizer"]["weight_decay"])

# Training loop settings: epoch count and print interval.
TE = config["train_model"]["epochs"]
PI = config["train_model"]["PrintI"]

# Hidden dimension passed to GNNModel (the constant's spelling is kept because
# other cells refer to it by this name).
HIDEEN_DIM = config["GNNModel"]["hidden_dim"]

# Number of nodes to explain on the original graph (OG) and on the proxy
# graph (PG).
OG = config["OG"]["node_num"]
PG = config["PG"]["node_num"]

# Inclusive search range for the number of graph partitions.
PARTITION_KMIN = config["partition_k"]["k_min"]
PARTITION_KMAX = config["partition_k"]["k_max"]



# OriGraph

## Training

In [None]:
from utils.graph_processing import load_and_describe_dataset

# Fetch the configured dataset; the helper returns the dataset object together
# with the graph `data` that the rest of the notebook operates on.
dataset, data = load_and_describe_dataset(
    dataset_name=DA_NA,
    root=DA_RO,
)

In [None]:
from models.gnn_models import GNNModel
from utils.model_utils import train_model

# Build the GNN with dimensions taken from the dataset and the config, then
# move both the model and the graph onto the selected device.
model = GNNModel(
    num_node_features=dataset.num_node_features,
    hidden_dim=HIDEEN_DIM,
    num_classes=dataset.num_classes,
).to(device)
data = data.to(device)

# Adam optimizer over all model parameters with the configured learning rate
# and weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=OPLR,
    weight_decay=OPWE,
)

# Run training; the helper returns the trained model.
trained_model = train_model(
    model,
    data,
    optimizer,
    num_epochs=TE,
    print_interval=PI,
)

In [None]:
from utils.model_utils import evaluate_model

# Evaluate the trained model on the graph and keep the returned value.
test_accuracy = evaluate_model(
    trained_model,
    data,
)

## Interpret On OriGraph

In [None]:
from utils.explainer_utils import calculate_fidelity

# Per-node results collected from calculate_fidelity on the original graph.
fid_positive_list = []
fid_negative_list = []
probs = []
labels = []

total_nodes = data.num_nodes if hasattr(data, 'num_nodes') else data.x.shape[0]

# Explain at most OG nodes. The cap uses total_nodes (not data.num_nodes) so
# the fallback computed above is honoured consistently.
num_nodes_to_process = min(OG, total_nodes)
random_indices = random.sample(range(total_nodes), num_nodes_to_process)

for node_idx in tqdm(random_indices):
    fid_pos, fid_neg, prob, label = calculate_fidelity(trained_model, data, node_idx)
    if fid_pos is None or fid_neg is None:
        # Nodes without a usable score are reported and left out of the averages.
        print(f"Node {node_idx}: Fidelity scores are None.")
        continue
    fid_positive_list.append(fid_pos)
    fid_negative_list.append(fid_neg)
    probs.append(prob)
    labels.append(label)

if fid_positive_list:
    mean_fid_positive = np.mean(fid_positive_list)
    mean_fid_negative = np.mean(fid_negative_list)
    print(f'Fid+：{mean_fid_positive:.4f}')
    print(f'Fid-：{mean_fid_negative:.4f}')
else:
    # np.mean([]) would emit a RuntimeWarning and return NaN; make the empty
    # case explicit instead.
    mean_fid_positive = mean_fid_negative = float('nan')
    print("No valid fidelity scores were collected; Fid+ / Fid- are undefined.")


# Graph Partition

In [None]:
from utils.graph_partition import partition_graph, extract_partition_subgraph
from utils.model_utils import evaluate_proxy_model
import optuna
import numpy as np

def objective(trial):
    """Optuna objective: average proxy validation loss for a partition count.

    Suggests `nparts` in [PARTITION_KMIN, PARTITION_KMAX], then repeats five
    times: partition the graph, pick one partition at random, extract it as a
    subgraph and evaluate a proxy model on it. The averaged loss is returned.

    The best averaged loss seen so far and its matching averaged accuracy are
    stored on the function object as `objective.best_loss` and
    `objective.best_acc`, and a message is printed whenever they improve.
    """
    nparts = trial.suggest_int('nparts', PARTITION_KMIN, PARTITION_KMAX)

    repeats = 5
    loss_sum = 0
    acc_sum = 0
    for _ in range(repeats):
        partition_ids = partition_graph(data, nparts)
        chosen = np.random.randint(0, nparts)
        subgraph_data = extract_partition_subgraph(data, partition_ids, chosen)
        loss, acc = evaluate_proxy_model(subgraph_data)
        loss_sum += loss
        acc_sum += acc

    avg_val_loss = loss_sum / repeats
    avg_val_acc = acc_sum / repeats

    # The first call has no stored best yet, so it always counts as an improvement.
    improved = not hasattr(objective, 'best_loss') or avg_val_loss < objective.best_loss
    if improved:
        objective.best_loss = avg_val_loss
        objective.best_acc = avg_val_acc
        print(f"New Best: nparts = {nparts}, avg_val_loss = {avg_val_loss:.4f}, avg_val_acc = {avg_val_acc:.4f}")

    return avg_val_loss

# Search for the partition count that minimises the objective over 50 trials.
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=50)

# Keep the winning partition count for the graph-combining step.
best_nparts = study.best_params['nparts']

print(f"Best nparts = {best_nparts}, Best Loss = {study.best_value:.4f}")
print(f"Best Accuracy = {objective.best_acc:.4f}")

# Graph Combining

In [None]:
from utils.graph_processing import combined_graph

# Build the combined subgraph using the partition count chosen by the study.
combined_subgraph = combined_graph(
    data,
    best_nparts,
    device=device,
)
print(f"\nFinally Combine sum: {combined_subgraph.num_nodes}")

In [None]:
import importlib

from utils import model_utils

# Reload the module BEFORE binding the function name. importlib.reload does not
# rebind names that were imported earlier with `from ... import ...`, so
# importing first and reloading afterwards would keep calling the stale,
# pre-reload function.
importlib.reload(model_utils)
from utils.model_utils import train_proxy_model_on_combined_subgraph

# Train the proxy model; the helper returns the proxy model together with the
# combined subgraph, which replaces the previous `combined_subgraph` binding.
proxy_model, combined_subgraph = train_proxy_model_on_combined_subgraph(data, combined_subgraph, model)

# Interpret on subgraph

In [None]:
# Per-node results collected from calculate_fidelity for the proxy model.
fid_positive_list, fid_negative_list, probs_list, labels_list = [], [], [], []
total_nodes = combined_subgraph.num_nodes
num_nodes_to_process = min(PG, total_nodes)
# Graphs with at most 50 nodes are processed in full; larger ones are sampled
# down to num_nodes_to_process random nodes.
nodes_to_process = range(total_nodes) if total_nodes <= 50 else random.sample(range(total_nodes), num_nodes_to_process)

for node_idx in tqdm(nodes_to_process):
    # NOTE(review): node_idx is drawn from combined_subgraph's index range but
    # the fidelity is computed against `data` -- confirm whether
    # `combined_subgraph` was intended here.
    fid_pos, fid_neg, prob, label = calculate_fidelity(proxy_model, data, node_idx)
    if fid_pos is not None and fid_neg is not None:
        fid_positive_list.append(fid_pos)
        fid_negative_list.append(fid_neg)
        probs_list.append(prob)
        labels_list.append(label)

# Aggregate once, after every node has been processed. Doing this inside the
# loop would recompute and print the metrics on every iteration and fail on
# `probs_array.shape[1]` while the result lists are still empty.
if not fid_positive_list:
    mean_fid_positive = mean_fid_negative = float('nan')
    print("No valid fidelity scores were collected; Fid+ / Fid- / AUC-ROC are undefined.")
else:
    mean_fid_positive = np.mean(fid_positive_list)
    mean_fid_negative = np.mean(fid_negative_list)
    print(f'Average Fid+: {mean_fid_positive:.4f}')
    print(f'Average Fid-: {mean_fid_negative:.4f}')

    classes = np.arange(dataset.num_classes)
    labels_binarized = label_binarize(labels_list, classes=classes)
    probs_array = np.array(probs_list)

    # Force the probability matrix to exactly num_classes columns: pad missing
    # columns with zeros or drop surplus ones.
    if probs_array.shape[1] != dataset.num_classes:
        if probs_array.shape[1] < dataset.num_classes:
            probs_array = np.hstack((probs_array, np.zeros((probs_array.shape[0], dataset.num_classes - probs_array.shape[1]))))
        else:
            probs_array = probs_array[:, :dataset.num_classes]

    unique_labels = np.unique(labels_list)
    if len(unique_labels) < 2:
        print("Warning: Only one class present in predictions. AUC-ROC cannot be calculated.")
    else:
        try:
            auc_roc = (
                roc_auc_score(labels_binarized, probs_array[:, 1]) if dataset.num_classes == 2
                else roc_auc_score(labels_binarized, probs_array, average='macro', multi_class='ovr')
            )
        except ValueError as err:
            # roc_auc_score raises ValueError when the score is undefined for
            # the given labels; report it instead of aborting the cell.
            print(f"AUC-ROC cannot be calculated: {err}")
        else:
            print(f'AUC-ROC: {auc_roc:.4f}')