{
 "cells": [
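  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exploring Graph Agentic Network\n",
    "\n",
    "This notebook demonstrates how to use the Graph Agentic Network (GAN) framework for node classification tasks. Throughout this notebook, GAN abbreviates Graph Agentic Network, not Generative Adversarial Network."
   ]
  },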
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Make the project root (parent directory) importable. The membership check keeps\n",
    "# re-runs of this cell from appending duplicate sys.path entries, and the absolute\n",
    "# path stays valid even if the working directory changes later in the session.\n",
    "project_root = os.path.abspath('..')\n",
    "if project_root not in sys.path:\n",
    "    sys.path.append(project_root)\n",
    "\n",
    "# Import GAN project modules (these resolve only after the sys.path update above)\n",
    "import config\n",
    "from gan.llm import MockLLMInterface\n",
    "from gan.graph import GraphAgenticNetwork\n",
    "from data.dataset import load_or_create_dataset\n",
    "from gan.utils import seed_everything, visualize_graph, evaluate_node_classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Check GPU Availability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the compute device: the first CUDA GPU when one is available, otherwise the CPU.\n",
    "# NOTE(review): `device` is not referenced by any later cell, so this cell is informational only.\n",
    "cuda_ok = torch.cuda.is_available()\n",
    "if not cuda_ok:\n",
    "    print(\"CUDA is not available, using CPU\")\n",
    "    device = torch.device(\"cpu\")\n",
    "else:\n",
    "    n_gpus = torch.cuda.device_count()\n",
    "    print(f\"CUDA is available with {n_gpus} device(s)\")\n",
    "    for gpu_idx in range(n_gpus):\n",
    "        print(f\"  Device {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}\")\n",
    "    device = torch.device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load a Dataset\n",
    "\n",
    "For quick experimentation, we'll use a small subgraph of the OGB-Arxiv dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set random seeds (project helper) for reproducibility\n",
    "seed_everything(42)\n",
    "\n",
    "# Dataset options: a smaller subgraph keeps processing fast\n",
    "use_subgraph = True\n",
    "subgraph_size = 500  # number of nodes kept in the subgraph\n",
    "\n",
    "dataset = load_or_create_dataset('ogbn-arxiv', use_subgraph=use_subgraph, subgraph_size=subgraph_size)\n",
    "\n",
    "# Unpack the components the remaining cells rely on\n",
    "adj_matrix, node_features, labels = (dataset[key] for key in ('adj_matrix', 'node_features', 'labels'))\n",
    "train_idx, val_idx, test_idx = (dataset[key] for key in ('train_idx', 'val_idx', 'test_idx'))\n",
    "num_classes = dataset['num_classes']\n",
    "\n",
    "# Summary of what was loaded\n",
    "graph_kind = 'subgraph of ' if use_subgraph else ''\n",
    "print(f\"Loaded {graph_kind}OGB-Arxiv dataset\")\n",
    "print(f\"  Nodes: {adj_matrix.shape[0]}\")\n",
    "# Halving the adjacency sum assumes each undirected edge is stored in both directions (symmetric matrix)\n",
    "num_edges = adj_matrix.sum().item() / 2\n",
    "print(f\"  Edges: {num_edges:.0f}\")\n",
    "print(f\"  Features: {node_features.shape[1]}\")\n",
    "print(f\"  Classes: {num_classes}\")\n",
    "print(f\"  Train/Val/Test split: {len(train_idx)}/{len(val_idx)}/{len(test_idx)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Visualize the Graph\n",
    "\n",
    "Let's visualize a portion of the graph to understand its structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Larger graphs are subsampled to at most this many nodes before drawing\n",
    "max_nodes_to_plot = 100\n",
    "num_nodes = adj_matrix.shape[0]\n",
    "\n",
    "if num_nodes <= max_nodes_to_plot:\n",
    "    # Small enough to draw every node\n",
    "    sub_adj, sub_labels = adj_matrix, labels\n",
    "    title = \"Full graph\"\n",
    "else:\n",
    "    # Random node subset plus the adjacency rows/columns it induces\n",
    "    subset_nodes = torch.randperm(num_nodes)[:max_nodes_to_plot]\n",
    "    sub_adj = adj_matrix[subset_nodes][:, subset_nodes]\n",
    "    sub_labels = labels[subset_nodes]\n",
    "    title = f\"Subset of {max_nodes_to_plot} nodes\"\n",
    "\n",
    "# Colour each node by its class label\n",
    "visualize_graph(sub_adj, node_colors=sub_labels.numpy(), title=title)"
   ]
  },
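  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Create a Graph Agentic Network\n",
    "\n",
    "For the purpose of this demo, we'll use a mock LLM interface to avoid requiring an actual LLM connection. Because no real model is queried, the accuracies reported below should not be read as a measure of real GAN performance."
   ]
  },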
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A mock LLM interface stands in for a real model, so no LLM connection is needed\n",
    "llm_interface = MockLLMInterface()\n",
    "\n",
    "# Assemble the network over the loaded graph.\n",
    "# NOTE(review): labels for all nodes (including val/test) are passed in -- confirm the\n",
    "# network only consults training labels, otherwise the test metrics below are leaked.\n",
    "num_layers = 2\n",
    "gan_config = dict(\n",
    "    adj_matrix=adj_matrix,\n",
    "    node_features=node_features,\n",
    "    llm_interface=llm_interface,\n",
    "    labels=labels,\n",
    "    num_layers=num_layers,\n",
    ")\n",
    "gan = GraphAgenticNetwork(**gan_config)\n",
    "\n",
    "print(f\"Created Graph Agentic Network with {num_layers} layers\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Run the Graph Agentic Network\n",
    "\n",
    "Now let's run the GAN on our dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nodes are processed in batches of this size, which matters for larger graphs\n",
    "batch_size = 50\n",
    "print(\"Running Graph Agentic Network...\")\n",
    "gan.forward(batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Evaluate Results\n",
    "\n",
    "Let's evaluate the performance of our GAN model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the network's node predictions\n",
    "predictions = gan.get_node_predictions()\n",
    "\n",
    "# Score every split with the same evaluation helper (train, then val, then test)\n",
    "split_indices = {'train': train_idx, 'val': val_idx, 'test': test_idx}\n",
    "split_metrics = {\n",
    "    split: evaluate_node_classification(predictions[idx], labels[idx])\n",
    "    for split, idx in split_indices.items()\n",
    "}\n",
    "train_metrics, val_metrics, test_metrics = (split_metrics[s] for s in ('train', 'val', 'test'))\n",
    "\n",
    "print(\"GAN Results:\")\n",
    "print(f\"  Train Accuracy: {train_metrics['accuracy']:.4f}\")\n",
    "print(f\"  Val Accuracy: {val_metrics['accuracy']:.4f}\")\n",
    "print(f\"  Test Accuracy: {test_metrics['accuracy']:.4f}\")\n",
    "print(f\"  Test F1 (Micro): {test_metrics['f1_micro']:.4f}\")\n",
    "print(f\"  Test F1 (Macro): {test_metrics['f1_macro']:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Analyze Node Actions\n",
    "\n",
    "Let's analyze what actions the nodes took during the GAN processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-node action history recorded by the network\n",
    "node_memory = gan.get_node_memory()\n",
    "\n",
    "# Tally action types across all nodes (the dict keeps first-seen order)\n",
    "action_counts = {}\n",
    "for memory in node_memory.values():\n",
    "    for entry in memory:\n",
    "        action = entry['result']['action']\n",
    "        action_counts[action] = action_counts.get(action, 0) + 1\n",
    "\n",
    "# Plot at explicit integer positions with string tick labels. Passing dict views or\n",
    "# non-string keys straight to bar() may place bars at numeric x values, which would\n",
    "# misalign the count labels written at 0..n-1 below.\n",
    "action_names = [str(action) for action in action_counts]\n",
    "counts = list(action_counts.values())\n",
    "positions = list(range(len(counts)))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.bar(positions, counts)\n",
    "ax.set_xticks(positions)\n",
    "ax.set_xticklabels(action_names)\n",
    "ax.set_title('Action Type Distribution')\n",
    "ax.set_xlabel('Action Type')\n",
    "ax.set_ylabel('Count')\n",
    "\n",
    "# Write each count just above its bar\n",
    "for pos, count in zip(positions, counts):\n",
    "    ax.text(pos, count + 0.5, str(count), ha='center')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Examine Individual Node Behavior\n",
    "\n",
    "Let's examine the behavior of individual nodes in more detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick one node uniformly at random and print its recorded history\n",
    "node_id = np.random.randint(0, adj_matrix.shape[0])\n",
    "node_history = node_memory.get(node_id, [])  # nodes absent from memory get an empty history\n",
    "\n",
    "print(f\"Examining behavior of Node {node_id}\")\n",
    "print(f\"True label: {labels[node_id].item()}\")\n",
    "predicted_label = predictions[node_id].item() if node_id < len(predictions) else 'N/A'\n",
    "print(f\"Predicted label: {predicted_label}\")\n",
    "print(f\"Number of actions: {len(node_history)}\\n\")\n",
    "\n",
    "for step, record in enumerate(node_history, start=1):\n",
    "    print(f\"Action {step} (Layer {record['layer']})\")\n",
    "    outcome = record['result']\n",
    "    print(f\"  Type: {outcome['action']}\")\n",
    "    print(f\"  Details: {outcome}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Next Steps\n",
    "\n",
    "To extend this exploration, you could:\n",
    "\n",
    "1. Replace the MockLLMInterface with a real LLM connection\n",
    "2. Try different datasets or larger subgraphs\n",
    "3. Experiment with different numbers of layers and batch sizes\n",
    "4. Implement custom node agent behaviors\n",
    "5. Compare with traditional GNN baselines"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

In [1]:
import sys
import os

# Make the project root (parent directory) importable. The membership check keeps
# re-runs of this cell from appending duplicate sys.path entries.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)


In [2]:
from data.dataset import load_or_create_dataset
dataset = load_or_create_dataset("ogbn-arxiv", use_subgraph=True, subgraph_size=100)

# Summarize each returned component: its Python type and its shape (N/A when it has none)
for name, component in dataset.items():
    shape = getattr(component, 'shape', 'N/A')
    print(f"{name}: {type(component)}, shape: {shape}")


Loading OGB-Arxiv dataset from /common/home/mg1998/Graph/GAN/Graph_Agentic_Network/data/ogbn-arxiv...


  self.data, self.slices = torch.load(self.processed_paths[0])


: 