In [1]:
!pip install ogb



In [2]:
pip install torch_geometric

Note: you may need to restart the kernel to use updated packages.


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
import functools

# --- 【緊急パッチ】PyTorch 2.6のセキュリティ制限を解除 ---
_original_load = torch.load

def unsafe_load(*args, **kwargs):
    # 強制的にセキュリティチェックを無効化 (weights_only=False)
    if 'weights_only' not in kwargs:
        kwargs['weights_only'] = False
    return _original_load(*args, **kwargs)

torch.load = unsafe_load
# -------------------------------------------------------

print("データをロード中（パッチ適用済み）...")
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='./data')
data = dataset[0]

# パッチを戻す（念のため）
torch.load = _original_load

print("--------------------------------")
print("【セットアップ完了】")
print(f"ノード数: {data.num_nodes}")
print(f"特徴量次元: {data.x.shape[1]}")
print("これでNeural ODEの実装に進めます。")
print("--------------------------------")

データをロード中（パッチ適用済み）...
--------------------------------
【セットアップ完了】
ノード数: 169343
特徴量次元: 128
これでNeural ODEの実装に進めます。
--------------------------------


In [11]:
import pandas as pd

# 1. データ全体の要約を見る
print("=== データの全体像 ===")
print(data) 

# 2. 特徴量（x）の中身を少しだけ見る
print("\n=== 論文の特徴量（最初の5件） ===")
# 128個の数字の並びが、1つの論文の意味を表しています
df_x = pd.DataFrame(data.x[:5].numpy())
print(df_x)

# 3. 引用関係（edge_index）の中身を見る
print("\n=== 引用関係（最初の5リンク） ===")
# [引用元ID, 引用先ID] のペアです
edges = data.edge_index.t()[:]
df_edges = pd.DataFrame(edges.numpy(), columns=["引用元ID", "引用先ID"])
print(df_edges)

# 4. 年データを見る
print("\n=== 出版年（最初の5件） ===")
print(data.node_year[:5].flatten())

=== データの全体像 ===
Data(num_nodes=169343, edge_index=[2, 1166243], x=[169343, 128], node_year=[169343, 1], y=[169343, 1])

=== 論文の特徴量（最初の5件） ===
        0         1         2         3         4         5         6    \
0 -0.057943 -0.052530 -0.072603 -0.026555  0.130435 -0.241386 -0.449242   
1 -0.124500 -0.070665 -0.325202  0.007779 -0.001559  0.074189 -0.191013   
2 -0.080242 -0.023328 -0.183787 -0.180707  0.075765 -0.125818 -0.394573   
3 -0.145044  0.054915 -0.126666  0.039971 -0.055909 -0.101278 -0.339202   
4 -0.071154  0.070766 -0.281432 -0.161892 -0.165246 -0.029116 -0.338593   

        7         8         9    ...       118       119       120       121  \
0 -0.018443 -0.087218  0.112320  ...  0.211490 -0.226118 -0.185603  0.053230   
1  0.049689  0.026369  0.099364  ...  0.106316  0.052926 -0.258378  0.021567   
2 -0.219078 -0.108931  0.056966  ...  0.019453 -0.070291 -0.177562 -0.214012   
3 -0.115801 -0.080058 -0.001633  ... -0.065752  0.042735  0.066338 -0.226921   
4 -0.13

In [13]:
df_edges

Unnamed: 0,引用元ID,引用先ID
0,104447,13091
1,15858,47283
2,107156,69161
3,107156,136440
4,107156,107366
...,...,...
1166238,45118,79124
1166239,45118,147994
1166240,45118,162473
1166241,45118,162537


In [None]:
import os

# --- OGBマッピングの動的ロード (再現用) ---
# 実際のデータが読み込まれていることを確認します。
root_dir = './data'
mapping_path = os.path.join(root_dir, 'ogbn_arxiv/mapping/labelidx2arxivcategeory.csv.gz')

try:
    label_df = pd.read_csv(mapping_path)
    # IDをキー、カテゴリコードを値とする辞書を作成
    ARXIV_CATEGORY_NAMES = dict(zip(label_df['label idx'], label_df['arxiv category']))
except Exception:
    # 失敗した場合はダミーデータを使用（元のコードのロジックを再現）
    ARXIV_CATEGORY_NAMES = {i: f'Dummy_Cat_{i}' for i in range(40)}
# ----------------------------------------


# IDとカテゴリコードのペアを抽出し、データフレームを作成
category_data = []
# 0から39までのIDを明示的に処理
for cat_id in range(40):
    category_code = ARXIV_CATEGORY_NAMES.get(cat_id, 'N/A')
    category_data.append([cat_id, category_code])

# DataFrameの作成
df_categories = pd.DataFrame(category_data, columns=['ID', 'Arxiv Category Code'])

print("=== OGBN-Arxiv カテゴリ ID (0-39) 対照表 ===")
df_categories



=== OGBN-Arxiv カテゴリ ID (0-39) 対照表 ===


  from .autonotebook import tqdm as notebook_tqdm


Unnamed: 0,ID,Arxiv Category Code
0,0,arxiv cs na
1,1,arxiv cs mm
2,2,arxiv cs lo
3,3,arxiv cs cy
4,4,arxiv cs cr
5,5,arxiv cs dc
6,6,arxiv cs hc
7,7,arxiv cs ce
8,8,arxiv cs ni
9,9,arxiv cs cc


In [None]:
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.data import Data
from torch_geometric.utils import negative_sampling
from torchdiffeq import odeint_adjoint as odeint
from sklearn.metrics import roc_auc_score, average_precision_score
import umap
from scipy.interpolate import griddata
import networkx as nx
from collections import defaultdict
import random
import sys 
import csv 
from contextlib import nullcontext 

# --- グローバル変数初期化 ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# PDFおよびPNGファイルをまとめるディレクトリ作成
OUTPUT_DIR = "visualizations_output"
os.makedirs(OUTPUT_DIR, exist_ok=True) 
print(f"✓ Output directory created: {OUTPUT_DIR}")

# OGBロード時の 'maximum recursion depth exceeded' エラー対策
sys.setrecursionlimit(3000) 

# PyTorch 2.6のセキュリティ制限解除パッチ
_original_load = torch.load
def unsafe_load(*args, **kwargs):
    if 'weights_only' not in kwargs:
        kwargs['weights_only'] = False
    return _original_load(*args, **kwargs)
torch.load = unsafe_load

# --- OGBマッピングの動的ロード ---
try:
    from ogb.nodeproppred import PygNodePropPredDataset
    # マッピングパス取得のため一時的にインスタンス化を試みる
    try:
        dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='./data')
        mapping_path = os.path.join(dataset.root, 'mapping', 'labelidx2arxivcategeory.csv.gz')
        label_df = pd.read_csv(mapping_path)
        ARXIV_CATEGORY_NAMES = dict(zip(label_df['label idx'], label_df['arxiv category']))
        print(f"\n✓ OGB-Arxiv Categories Loaded: {len(ARXIV_CATEGORY_NAMES)} unique labels.")
    except Exception as map_e:
        raise Exception(f"Failed to load map after dataset instance: {map_e}")

except Exception as e:
    print(f"⚠️ OGB Category Mapping Load Error: {e}")
    print("Using DUMMY Category Names.")
    ARXIV_CATEGORY_NAMES = {i: f'Dummy_Cat_{i}' for i in range(40)}

# ----------------------------------------------------


# ==========================================
# 1. Universal Data Adapter
# ==========================================
class UniversalDataFactory:
    def __init__(self, dataset_name, root_dir='./data'):
        self.dataset_name = dataset_name
        self.root_dir = root_dir
        
    def load_data(self):
        print(f"Loading dataset: {self.dataset_name}...")
        try:
            from ogb.nodeproppred import PygNodePropPredDataset
            dataset = PygNodePropPredDataset(name='ogbn-arxiv', root=self.root_dir)
            data = dataset[0]
            print(f"✓ Loaded OGB dataset successfully")
        except Exception as e:
            print(f"OGB Load Error: {e}. Creating DUMMY data for demo.")
            num_nodes = 2000
            edge_index = torch.randint(0, num_nodes, (2, 10000))
            x = torch.randn(num_nodes, 128)
            y = torch.randint(0, 40, (num_nodes,))
            node_year = torch.randint(2015, 2021, (num_nodes, 1))
            data = Data(x=x, edge_index=edge_index, y=y.unsqueeze(1), node_year=node_year)

        num_papers = data.num_nodes
        categories = data.y.numpy().flatten()
        df_nodes = pd.DataFrame({
            'node_id': range(num_papers),
            'year': data.node_year.numpy().flatten(),
            'category': categories
        })
        edge_index = data.edge_index.numpy()
        df_edges = pd.DataFrame({'source': edge_index[0], 'target': edge_index[1]})
        num_categories = len(np.unique(categories))
        
        print(f"✓ Data loaded: {num_papers} papers, {num_categories} categories.")
        return {
            'df_nodes': df_nodes,
            'df_edges': df_edges,
            'node_features': data.x,
            'num_categories': num_categories
        }

class DynamicGraphBuilder:
    def __init__(self, adapter_output):
        self.data = adapter_output
        
    def build_snapshots(self):
        df_nodes = self.data['df_nodes']
        df_edges = self.data['df_edges']
        feats = self.data['node_features']
        years = sorted(df_nodes['year'].unique())
        snapshots = {}
        num_cats = self.data['num_categories']
        offset = num_cats
        
        print("Building temporal snapshots...")
        
        for year in years:
            if year < 2015:
                continue
                
            active_papers = df_nodes[df_nodes['year'] <= year]
            active_ids = active_papers['node_id'].values
            valid_edges = df_edges[
                (df_edges['source'].isin(active_ids)) & 
                (df_edges['target'].isin(active_ids))
            ]
            
            # カテゴリ→論文のエッジ
            cat_src = torch.tensor(active_papers['category'].values, dtype=torch.long)
            paper_dst = torch.tensor(active_papers['node_id'].values + offset, dtype=torch.long)
            
            # 論文→論文のエッジ
            src_paper = torch.tensor(valid_edges['source'].values + offset, dtype=torch.long)
            dst_paper = torch.tensor(valid_edges['target'].values + offset, dtype=torch.long)
            
            edge_index = torch.cat([
                torch.stack([cat_src, paper_dst], dim=0),
                torch.stack([src_paper, dst_paper], dim=0)
            ], dim=1)
            
            # ラベル設定
            total_nodes = offset + len(df_nodes)
            full_y = torch.full((total_nodes,), -1, dtype=torch.long)
            full_y[:num_cats] = torch.arange(num_cats)
            paper_indices = df_nodes['node_id'].values + offset
            paper_cats = df_nodes['category'].values
            full_y[paper_indices] = torch.tensor(paper_cats, dtype=torch.long)
            
            snapshots[year] = Data(
                x=feats,
                edge_index=edge_index,
                num_nodes=total_nodes,
                y=full_y
            )
            
        print(f"✓ Built {len(snapshots)} snapshots.")
        return snapshots, num_cats, feats.shape[1]

# ==========================================
# 2. Advanced Negative Sampling Strategies
# ==========================================
class NegativeSampler:
    """複数のネガティブサンプリング戦略を提供"""
    
    @staticmethod
    def random_negative_sampling(pos_edge_index, num_nodes, num_samples):
        return negative_sampling(pos_edge_index, num_nodes=num_nodes, 
                                num_neg_samples=num_samples)
    
    @staticmethod
    def historical_negative_sampling(train_edges_history, current_pos_edges, 
                                     num_nodes, num_samples):
        historical_set = set()
        for edges in train_edges_history:
            src, dst = edges[0].cpu().numpy(), edges[1].cpu().numpy()
            historical_set.update(zip(src, dst))
        
        current_src, current_dst = current_pos_edges[0].cpu().numpy(), current_pos_edges[1].cpu().numpy()
        current_set = set(zip(current_src, current_dst))
        
        historical_negatives = list(historical_set - current_set)
        
        if len(historical_negatives) < num_samples:
            extra_needed = num_samples - len(historical_negatives)
            random_neg = negative_sampling(current_pos_edges, num_nodes=num_nodes,
                                          num_neg_samples=extra_needed)
            
            if len(historical_negatives) > 0:
                hist_neg_array = np.array(historical_negatives).T
                hist_neg_tensor = torch.tensor(hist_neg_array, dtype=torch.long)
                return torch.cat([hist_neg_tensor, random_neg], dim=1)
            else:
                return random_neg
        else:
            sampled_negatives = random.sample(historical_negatives, num_samples)
            neg_array = np.array(sampled_negatives).T
            return torch.tensor(neg_array, dtype=torch.long)
    
    @staticmethod
    def inductive_negative_sampling(test_history, current_test_edges, 
                                    num_nodes, num_samples):
        test_historical_set = set()
        for edges in test_history:
            src, dst = edges[0].cpu().numpy(), edges[1].cpu().numpy()
            test_historical_set.update(zip(src, dst))
        
        current_src, current_dst = current_test_edges[0].cpu().numpy(), current_test_edges[1].cpu().numpy()
        current_set = set(zip(current_src, current_dst))
        
        inductive_negatives = list(test_historical_set - current_set)
        
        if len(inductive_negatives) < num_samples:
            extra_needed = num_samples - len(inductive_negatives)
            random_neg = negative_sampling(current_test_edges, num_nodes=num_nodes,
                                          num_neg_samples=extra_needed)
            
            if len(inductive_negatives) > 0:
                ind_neg_array = np.array(inductive_negatives).T
                ind_neg_tensor = torch.tensor(ind_neg_array, dtype=torch.long)
                return torch.cat([ind_neg_tensor, random_neg], dim=1)
            else:
                return random_neg
        else:
            sampled_negatives = random.sample(inductive_negatives, num_samples)
            neg_array = np.array(sampled_negatives).T
            return torch.tensor(neg_array, dtype=torch.long)

# ==========================================
# 3. Baseline Models 
# ==========================================

class StaticGCN(nn.Module):
    def __init__(self, num_cats, feat_dim, hidden_dim):
        super(StaticGCN, self).__init__()
        self.num_cats = num_cats
        self.cat_emb = nn.Embedding(num_cats, hidden_dim)
        self.feat_encoder = nn.Linear(feat_dim, hidden_dim)
        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.link_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def encode(self, x, edge_index, num_total_nodes):
        h_paper = self.feat_encoder(x)
        h_cat = self.cat_emb.weight
        h_all = torch.zeros(num_total_nodes, h_paper.size(1), device=x.device)
        h_all[:self.num_cats] = h_cat
        h_all[self.num_cats:self.num_cats+h_paper.size(0)] = h_paper
        
        h = F.relu(self.gcn1(h_all, edge_index))
        h = self.gcn2(h, edge_index)
        return h
    
    def predict_link(self, z, edge_index):
        src, dst = edge_index
        cat_feat = torch.cat([z[src], z[dst]], dim=-1)
        return torch.sigmoid(self.link_decoder(cat_feat)).view(-1)

class GCN_LSTM(nn.Module):
    def __init__(self, num_cats, feat_dim, hidden_dim):
        super(GCN_LSTM, self).__init__()
        self.num_cats = num_cats
        self.cat_emb = nn.Embedding(num_cats, hidden_dim)
        self.feat_encoder = nn.Linear(feat_dim, hidden_dim)
        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.link_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def encode_snapshot(self, x, edge_index, num_total_nodes):
        h_paper = self.feat_encoder(x)
        h_cat = self.cat_emb.weight
        h_all = torch.zeros(num_total_nodes, h_paper.size(1), device=x.device)
        h_all[:self.num_cats] = h_cat
        h_all[self.num_cats:self.num_cats+h_paper.size(0)] = h_paper
        
        h = F.relu(self.gcn1(h_all, edge_index))
        h = self.gcn2(h, edge_index)
        return h
    
    def forward(self, snapshots_data):
        embeddings = []
        for data in snapshots_data:
            z = self.encode_snapshot(data.x, data.edge_index, data.num_nodes)
            embeddings.append(z)
        
        if len(embeddings) > 1:
            emb_stack = torch.stack(embeddings, dim=1)
            lstm_out, _ = self.lstm(emb_stack)
            return lstm_out[:, -1, :]
        else:
            return embeddings[0]
    
    def predict_link(self, z, edge_index):
        src, dst = edge_index
        cat_feat = torch.cat([z[src], z[dst]], dim=-1)
        return torch.sigmoid(self.link_decoder(cat_feat)).view(-1)

class GraphSAGEModel(nn.Module):
    def __init__(self, num_cats, feat_dim, hidden_dim):
        super(GraphSAGEModel, self).__init__()
        self.num_cats = num_cats
        self.cat_emb = nn.Embedding(num_cats, hidden_dim)
        self.feat_encoder = nn.Linear(feat_dim, hidden_dim)
        self.sage1 = SAGEConv(hidden_dim, hidden_dim)
        self.sage2 = SAGEConv(hidden_dim, hidden_dim)
        self.link_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def encode(self, x, edge_index, num_total_nodes):
        h_paper = self.feat_encoder(x)
        h_cat = self.cat_emb.weight
        h_all = torch.zeros(num_total_nodes, h_paper.size(1), device=x.device)
        h_all[:self.num_cats] = h_cat
        h_all[self.num_cats:self.num_cats+h_paper.size(0)] = h_paper
        
        h = F.relu(self.sage1(h_all, edge_index))
        h = self.sage2(h, edge_index)
        return h
    
    def predict_link(self, z, edge_index):
        src, dst = edge_index
        cat_feat = torch.cat([z[src], z[dst]], dim=-1)
        return torch.sigmoid(self.link_decoder(cat_feat)).view(-1)

class SimpleMemoryGNN(nn.Module):
    def __init__(self, num_cats, feat_dim, hidden_dim):
        super(SimpleMemoryGNN, self).__init__()
        self.num_cats = num_cats
        self.cat_emb = nn.Embedding(num_cats, hidden_dim)
        self.feat_encoder = nn.Linear(feat_dim, hidden_dim)
        self.gcn = GCNConv(hidden_dim, hidden_dim)
        self.memory_dim = hidden_dim
        self.memory_updater = nn.GRUCell(hidden_dim, hidden_dim)
        self.link_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.memory = None
        
    def init_memory(self, num_total_nodes, device):
        self.memory = torch.zeros(num_total_nodes, self.memory_dim, device=device)
        
    def encode(self, x, edge_index, num_total_nodes):
        h_paper = self.feat_encoder(x)
        h_cat = self.cat_emb.weight
        h_all = torch.zeros(num_total_nodes, h_paper.size(1), device=x.device)
        h_all[:self.num_cats] = h_cat
        h_all[self.num_cats:self.num_cats+h_paper.size(0)] = h_paper
        
        h = F.relu(self.gcn(h_all, edge_index))
        h_with_memory = self.memory_updater(h, self.memory)
        self.memory = h_with_memory.detach()
        
        return h_with_memory
    
    def predict_link(self, z, edge_index):
        src, dst = edge_index
        cat_feat = torch.cat([z[src], z[dst]], dim=-1)
        return torch.sigmoid(self.link_decoder(cat_feat)).view(-1)

# --- GRAND: Graph Neural Diffusion (Section 3.1) ---
class GRANDFunc(nn.Module):
    def __init__(self, gnn_layer):
        super(GRANDFunc, self).__init__()
        self.gnn = gnn_layer
        self.edge_index = None

    def set_graph_structure(self, edge_index):
        self.edge_index = edge_index

    def forward(self, t, x):
        # 拡散方程式: dH/dt = GNN(H) - H
        ax = self.gnn(x, self.edge_index)
        return ax - x 

class GRAND_ODE(nn.Module):
    def __init__(self, num_cats, feat_dim, hidden_dim):
        super(GRAND_ODE, self).__init__()
        self.num_cats = num_cats
        self.cat_emb = nn.Embedding(num_cats, hidden_dim)
        self.feat_encoder = nn.Linear(feat_dim, hidden_dim)
        # GRANDはAttentionによる拡散が特徴
        self.gnn_layer = GATConv(hidden_dim, hidden_dim, heads=1, concat=False)
        self.ode_func = GRANDFunc(self.gnn_layer)
        self.link_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, num_total_nodes):
        h_paper = self.feat_encoder(x)
        h_cat = self.cat_emb.weight
        h_all = torch.zeros(num_total_nodes, h_paper.size(1), device=x.device)
        h_all[:self.num_cats] = h_cat
        h_all[self.num_cats:self.num_cats+h_paper.size(0)] = h_paper
        return h_all

    def forward(self, x, edge_index, t_span, num_total_nodes):
        z0 = self.encode(x, num_total_nodes)
        self.ode_func.set_graph_structure(edge_index)
        return odeint(self.ode_func, z0, t_span, method='dopri5')

    def predict_link(self, z, edge_index):
        src, dst = edge_index
        cat_feat = torch.cat([z[src], z[dst]], dim=-1)
        return torch.sigmoid(self.link_decoder(cat_feat)).view(-1)

# --- GREAD: Graph Neural Reaction-Diffusion (Section 3.2) ---
class GREADFunc(nn.Module):
    def __init__(self, hidden_dim, gnn_layer):
        super(GREADFunc, self).__init__()
        self.gnn = gnn_layer
        self.reaction = nn.Linear(hidden_dim, hidden_dim)
        self.edge_index = None

    def set_graph_structure(self, edge_index):
        self.edge_index = edge_index

    def forward(self, t, x):
        # 反応拡散: dH/dt = Diffusion(H) + Reaction(H)
        diffusion = self.gnn(x, self.edge_index) - x
        reaction = torch.tanh(self.reaction(x))
        return diffusion + reaction

class GREAD_ODE(nn.Module):
    def __init__(self, num_cats, feat_dim, hidden_dim):
        super(GREAD_ODE, self).__init__()
        self.num_cats = num_cats
        self.cat_emb = nn.Embedding(num_cats, hidden_dim)
        self.feat_encoder = nn.Linear(feat_dim, hidden_dim)
        self.gnn_layer = GCNConv(hidden_dim, hidden_dim)
        self.ode_func = GREADFunc(hidden_dim, self.gnn_layer)
        self.link_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, num_total_nodes):
        h_paper = self.feat_encoder(x)
        h_cat = self.cat_emb.weight
        h_all = torch.zeros(num_total_nodes, h_paper.size(1), device=x.device)
        h_all[:self.num_cats] = h_cat
        h_all[self.num_cats:self.num_cats+h_paper.size(0)] = h_paper
        return h_all

    def forward(self, x, edge_index, t_span, num_total_nodes):
        z0 = self.encode(x, num_total_nodes)
        self.ode_func.set_graph_structure(edge_index)
        return odeint(self.ode_func, z0, t_span, method='dopri5')
    
    def predict_link(self, z, edge_index):
        src, dst = edge_index
        cat_feat = torch.cat([z[src], z[dst]], dim=-1)
        return torch.sigmoid(self.link_decoder(cat_feat)).view(-1)

# ==========================================
# 4. InnoVeloODE: 2nd Order Implementation (提案手法)
# ==========================================

class SecondOrderODEFunc(nn.Module):
    def __init__(self, hidden_dim, gnn_layer, ablation_mode=None):
        super(SecondOrderODEFunc, self).__init__()
        self.gnn = gnn_layer
        self.ablation_mode = ablation_mode # 'no_velocity', 'no_decay', 'adaptive_decay'
        self.hidden_dim = hidden_dim
        
        # Adaptive Decay: 減衰係数をノードの状態から学習
        self.adaptive_decay_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Sigmoid()
        )
        self.fixed_damping = nn.Parameter(torch.tensor(0.5))
        self.edge_index = None
        
        # 可視化用コンポーネント保存
        self._store_components = False
        self.components = {}

    def set_graph_structure(self, edge_index):
        self.edge_index = edge_index

    def forward(self, t, z_augmented):
        # z_augmented = [Position H, Velocity V]
        dim = z_augmented.shape[1] // 2
        h = z_augmented[:, :dim]
        v = z_augmented[:, dim:]

        # 1. Diffusion Force (復元力): GNN output as target
        target_h = self.gnn(h, self.edge_index)
        diffusion_force = target_h - h 

        # 2. Decay/Damping Force (減衰力)
        if self.ablation_mode == 'no_decay':
            decay_force = torch.zeros_like(v)
        elif self.ablation_mode == 'adaptive_decay':
            gamma = self.adaptive_decay_net(h)
            decay_force = -gamma * v
        else: # linear_decay or default
            decay_force = -self.fixed_damping * v

        # 支配方程式: dH/dt = V, dV/dt = Force
        dh_dt = v
        
        if self.ablation_mode == 'no_velocity':
            # 1階ODEとして動作 (GRAND相当)
            dh_dt = diffusion_force
            dv_dt = torch.zeros_like(v)
        else:
            dv_dt = diffusion_force + decay_force

        if self._store_components:
            self.components = {
                'dz_dt': dh_dt, # 便宜上、位置の変化速度を保存
                'velocity': v,
                'diffusion': diffusion_force,
                'decay': decay_force
            }

        return torch.cat([dh_dt, dv_dt], dim=1)
    
    def set_component_storage(self, status):
        self._store_components = status
    def get_components(self):
        return self.components

class InnoVeloODE(nn.Module):
    def __init__(self, num_cats, feat_dim, hidden_dim, use_gat=False, gat_heads=2, ablation_mode='adaptive_decay'):
        super(InnoVeloODE, self).__init__()
        self.num_cats = num_cats
        self.hidden_dim = hidden_dim
        self.cat_emb = nn.Embedding(num_cats, hidden_dim)
        self.feat_encoder = nn.Linear(feat_dim, hidden_dim)
        
        # 初期速度 V0 の推定
        self.init_velocity = nn.Linear(hidden_dim, hidden_dim)

        if use_gat:
            self.gnn_layer = GATConv(hidden_dim, hidden_dim // gat_heads, heads=gat_heads, concat=True)
            print(f"✓ Using GAT ({gat_heads} heads) - Mode: {ablation_mode}")
        else:
            self.gnn_layer = GCNConv(hidden_dim, hidden_dim)
            print(f"✓ Using GCN - Mode: {ablation_mode}")
        
        self.ode_func = SecondOrderODEFunc(hidden_dim, self.gnn_layer, ablation_mode=ablation_mode)
        
        self.link_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, num_total_nodes):
        h_paper = self.feat_encoder(x)
        h_cat = self.cat_emb.weight
        h_all = torch.zeros(num_total_nodes, h_paper.size(1), device=x.device)
        h_all[:self.num_cats] = h_cat
        h_all[self.num_cats:self.num_cats+h_paper.size(0)] = h_paper
        
        # 初期速度の推定 (トレンドの勢い)
        v0 = torch.tanh(self.init_velocity(h_all))
        return torch.cat([h_all, v0], dim=1) # [H, V]

    def forward(self, x, edge_index, t_span, num_total_nodes):
        z0_augmented = self.encode(x, num_total_nodes)
        self.ode_func.set_graph_structure(edge_index)
        
        # ODE積分
        z_traj = odeint(self.ode_func, z0_augmented, t_span, method='rk4', rtol=1e-3, atol=1e-3)
        return z_traj

    def predict_link(self, z_augmented, edge_index):
        # リンク予測には位置情報 H (前半部分) のみを使用
        h = z_augmented[:, :self.hidden_dim]
        src, dst = edge_index
        cat_feat = torch.cat([h[src], h[dst]], dim=-1)
        return torch.sigmoid(self.link_decoder(cat_feat)).view(-1)
    
# ==========================================
# 5. Enhanced Evaluation Functions
# ==========================================
def evaluate_link_prediction_enhanced(model, data_t0, data_t1, device, 
                                     sampling_strategy='random',
                                     train_history=None, test_history=None,
                                     model_name="Model"):
    """拡張評価関数"""
    
    model.eval()
    with torch.no_grad():
        data_t0 = data_t0.to(device)
        data_t1 = data_t1.to(device)
        
        if isinstance(model, (InnoVeloODE, GRAND_ODE, GREAD_ODE)):
            t_span = torch.tensor([0.0, 1.0]).to(device)
            z_traj = model(data_t0.x, data_t0.edge_index, t_span, data_t0.num_nodes)
            z_pred = z_traj[-1]
        elif isinstance(model, GCN_LSTM):
            z_pred = model.encode_snapshot(data_t0.x, data_t0.edge_index, data_t0.num_nodes)
        elif isinstance(model, SimpleMemoryGNN):
            z_pred = model.encode(data_t0.x, data_t0.edge_index, data_t0.num_nodes)
        else:
            z_pred = model.encode(data_t0.x, data_t0.edge_index, data_t0.num_nodes)
        
        pos_edge = data_t1.edge_index
        
        if sampling_strategy == 'random':
            neg_edge = NegativeSampler.random_negative_sampling(
                pos_edge, data_t1.num_nodes, pos_edge.size(1)
            )
        elif sampling_strategy == 'historical' and train_history is not None:
            neg_edge = NegativeSampler.historical_negative_sampling(
                train_history, pos_edge, data_t1.num_nodes, pos_edge.size(1)
            )
        elif sampling_strategy == 'inductive' and test_history is not None:
            neg_edge = NegativeSampler.inductive_negative_sampling(
                test_history, pos_edge, data_t1.num_nodes, pos_edge.size(1)
            )
        else:
            neg_edge = NegativeSampler.random_negative_sampling(
                pos_edge, data_t1.num_nodes, pos_edge.size(1)
            )
        
        pos_score = model.predict_link(z_pred, pos_edge)
        neg_score = model.predict_link(z_pred, neg_edge)
        
        y_true = torch.cat([
            torch.ones_like(pos_score),
            torch.zeros_like(neg_score)
        ]).cpu().numpy()
        y_pred = torch.cat([pos_score, neg_score]).cpu().numpy()
        
        auc = roc_auc_score(y_true, y_pred)
        ap = average_precision_score(y_true, y_pred)
        
        return {
            'AUC': auc,
            'AP': ap,
            'pos_score_mean': pos_score.mean().item(),
            'neg_score_mean': neg_score.mean().item()
        }

def evaluate_multi_step_enhanced(model, snapshots, years, device, 
                                sampling_strategy='random', model_name="Model"):
    results = []
    
    train_history = []
    test_history = []
    
    for i in range(len(years) - 2):
        t0, t1, t2 = years[i], years[i+1], years[i+2]
        
        if i > 0:
            train_history.append(snapshots[years[i-1]].edge_index)
            test_history.append(snapshots[t1].edge_index)
        
        metrics_1step = evaluate_link_prediction_enhanced(
            model, snapshots[t0], snapshots[t1], device,
            sampling_strategy=sampling_strategy,
            train_history=train_history if len(train_history) > 0 else None,
            test_history=test_history if len(test_history) > 0 else None,
            model_name=model_name
        )
        
        metrics_2step = evaluate_link_prediction_enhanced(
            model, snapshots[t0], snapshots[t2], device,
            sampling_strategy=sampling_strategy,
            train_history=train_history if len(train_history) > 0 else None,
            test_history=test_history if len(test_history) > 0 else None,
            model_name=model_name
        )
        
        results.append({
            'model': model_name,
            'sampling': sampling_strategy,
            'time_window': f"{t0}->{t1}->{t2}",
            '1_step_AUC': metrics_1step['AUC'],
            '1_step_AP': metrics_1step['AP'],
            '2_step_AUC': metrics_2step['AUC'],
            '2_step_AP': metrics_2step['AP']
        })
    
    return pd.DataFrame(results)

# ==========================================
# 6. Training Function
# ==========================================
def train_model(model, snapshots, years, device, epochs=5, lr=0.01, model_name="Model"):
    """モデルの学習"""
    
    if isinstance(model, SimpleMemoryGNN):
        model.init_memory(snapshots[years[0]].num_nodes, device)
    
    # EdgeBankの訓練をスキップ
    if not isinstance(model, (StaticGCN, GCN_LSTM, GraphSAGEModel, SimpleMemoryGNN, InnoVeloODE, GRAND_ODE, GREAD_ODE)):
        print(f"Skipping training for non-parametric model: {model_name}.")
        return []
        
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_years = years[-4:] if len(years) >= 4 else years
    
    history = []
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_auc = 0
        steps = 0
        
        for i in range(len(train_years) - 1):
            t0, t1 = train_years[i], train_years[i+1]
            data_t0 = snapshots[t0].to(device)
            data_t1 = snapshots[t1].to(device)
            
            optimizer.zero_grad()
            
            # Forward Pass
            if isinstance(model, (InnoVeloODE, GRAND_ODE, GREAD_ODE)):
                t_span = torch.tensor([0.0, 1.0]).to(device)
                z_traj = model(data_t0.x, data_t0.edge_index, t_span, data_t0.num_nodes)
                z_t1_pred = z_traj[-1]
            elif isinstance(model, GCN_LSTM):
                z_t1_pred = model.encode_snapshot(data_t0.x, data_t0.edge_index, data_t0.num_nodes)
            else:
                z_t1_pred = model.encode(data_t0.x, data_t0.edge_index, data_t0.num_nodes)
            
            pos_edge_index = data_t1.edge_index
            neg_edge_index = negative_sampling(
                pos_edge_index, num_nodes=data_t1.num_nodes,
                num_neg_samples=pos_edge_index.size(1)
            )
            
            pos_score = model.predict_link(z_t1_pred, pos_edge_index)
            neg_score = model.predict_link(z_t1_pred, neg_edge_index)
            
            loss = -torch.log(pos_score + 1e-15).mean() - torch.log(1 - neg_score + 1e-15).mean()
            
            # 運動エネルギー正則化 (InnoVeloODEのみ) - インデントと位置を修正
            if isinstance(model, InnoVeloODE) and model.ode_func.ablation_mode != 'no_velocity':
                # z_trajは上のifブロックで定義されている
                z_final = z_traj[-1]
                v_final = z_final[:, model.hidden_dim:] 
                kinetic_energy = torch.mean(v_final ** 2)
                loss += 1.0 * kinetic_energy
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            steps += 1
            
            with torch.no_grad():
                y_true = torch.cat([
                    torch.ones_like(pos_score),
                    torch.zeros_like(neg_score)
                ]).cpu().numpy()
                y_pred = torch.cat([pos_score, neg_score]).cpu().numpy()
                total_auc += roc_auc_score(y_true, y_pred)
        
        avg_loss = total_loss / steps if steps > 0 else 0
        avg_auc = total_auc / steps if steps > 0 else 0
        
        history.append({'epoch': epoch+1, 'loss': avg_loss, 'auc': avg_auc})
        print(f"{model_name} - Epoch {epoch+1:02d}/{epochs} | Loss: {avg_loss:.4f} | AUC: {avg_auc:.4f}")
    
    return history

# ==========================================
# 7. Streamline Visualizer (瞬時速度場)
# ==========================================
class StreamlineVisualizer:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.category_names = ARXIV_CATEGORY_NAMES 

    def _get_embeddings_and_components(self, snapshot):
        data = snapshot.to(self.device)
        self.model.eval()
        
        with torch.no_grad():
            if isinstance(self.model, InnoVeloODE):
                z0 = self.model.encode(data.x, data.num_nodes)
                
                # 瞬時速度場（dz/dt）を取得
                self.model.ode_func.set_component_storage(True)
                dz_dt_t0 = self.model.ode_func(torch.tensor(0.0).to(z0.device), z0)
                self.model.ode_func.set_component_storage(False)
                
                z_start = z0.cpu().numpy()
                labels = data.y.cpu().numpy().flatten()
                
                components_tensor = self.model.ode_func.get_components()
                components = {k: v.cpu().numpy() for k, v in components_tensor.items()}
                
                velocities_t0 = dz_dt_t0.cpu().numpy()
            
            else:
                # ODEモデル以外の場合、ダミー速度場を生成
                z_start = self.model.encode(data.x, data.edge_index, data.num_nodes).cpu().numpy()
                velocities_t0 = np.random.randn(*z_start.shape) * 0.1 
                labels = data.y.cpu().numpy().flatten()
                components = {
                    'dz_dt': velocities_t0,
                    'velocity': velocities_t0 * 0.5,
                    'diffusion': velocities_t0 * 0.3,
                    'decay': velocities_t0 * 0.2
                }
            
            num_cats = self.model.num_cats
            # カテゴリノードのみを抽出
            valid_mask = (labels >= 0) & (labels < num_cats) 
        
        z_start = z_start[valid_mask]
        labels = labels[valid_mask]
        velocities_t0 = velocities_t0[valid_mask]
        for k in components:
            components[k] = components[k][valid_mask]

        if z_start.shape[0] == 0:
            return None, None, None, None
        
        # ノードのサンプリングを導入（最大3000ノード）みやすくするため
        num_points = z_start.shape[0]
        max_points = 3000  # 最大プロット数を設定
        if num_points > max_points:
            print(f"Sampling {max_points} out of {num_points} nodes for visualization...")
            idx = np.random.choice(num_points, max_points, replace=False)
            
            z_start = z_start[idx]
            labels = labels[idx]
            velocities_t0 = velocities_t0[idx]
            for k in components:
                components[k] = components[k][idx]

        # UMAPで2次元に圧縮 (n_neighborsを調整)
        reducer = umap.UMAP(n_components=2, n_neighbors=min(5, z_start.shape[0]-1), min_dist=0.1, random_state=42)
        emb_start = reducer.fit_transform(z_start)
        
        # 瞬時速度ベクトルをUMAP空間で近似
        epsilon = 0.1 
        emb_end_approx = reducer.transform(z_start + epsilon * velocities_t0)
        velocities_umap = (emb_end_approx - emb_start) / epsilon
        
        return emb_start, labels, components, velocities_umap

    def plot_streamline(self, snapshot, title="Tech Trend Streamline", save_prefix="streamline"):
        emb_start, labels, components, velocities_umap = self._get_embeddings_and_components(snapshot)
        if emb_start is None:
            print("No valid data for streamline visualization")
            return

        x, y = emb_start[:, 0], emb_start[:, 1]
        u, v = velocities_umap[:, 0], velocities_umap[:, 1] 
        
        # Net Velocity (dz/dt) の大きさを計算
        net_velocity_magnitude = np.linalg.norm(components['dz_dt'], axis=1)
        
        min_size = 5
        max_size_range = 195

        # 1. ノードサイズの動的スケーリング: 速度の大きさに応じてサイズを100から500の範囲でスケーリング
        if net_velocity_magnitude.max() > 0:
            scaled_size = min_size + max_size_range * (net_velocity_magnitude / net_velocity_magnitude.max())
        else:
            scaled_size = np.full_like(net_velocity_magnitude, 200)

        grid_dim = 50
        xi = np.linspace(x.min(), x.max(), grid_dim)
        yi = np.linspace(y.min(), y.max(), grid_dim)
        
        ui = griddata((x, y), u, (xi[None, :], yi[:, None]), method='linear')
        vi = griddata((x, y), v, (xi[None, :], yi[:, None]), method='linear')
        ui = np.nan_to_num(ui); vi = np.nan_to_num(vi)
        
        speed_grid = np.sqrt(ui**2 + vi**2)
        lw = 3.0 * speed_grid / speed_grid.max() if speed_grid.max() > 0 else 1.0

        fig, ax = plt.subplots(figsize=(20, 20))
        unique_labels = np.unique(labels)
        
        colors = plt.colormaps.get_cmap('tab20c')
        
        # 修正されたノードサイズを使用
        ax.scatter(x, y, c=labels, cmap=colors, s=scaled_size, alpha=0.4, edgecolors='k', linewidths=0.5)
        ax.streamplot(xi, yi, ui, vi, color='black', linewidth=lw, arrowsize=1.5, density=1.0, cmap='plasma')
        
        # 2. ラベルの強調
        for label in unique_labels:
            mask = (labels == label)
            if np.sum(mask) > 0:
                cx = np.mean(x[mask]); cy = np.mean(y[mask])
                cat_id_str = str(int(label))
                ax.text(cx, cy, cat_id_str, fontsize=12, fontweight='bold', 
                        bbox=dict(boxstyle="round,pad=0.5", fc='white', alpha=0.9, 
                                  ec=colors(label), linewidth=2)) # 外枠を強調

        # 保存パスにディレクトリを付加
        final_save_path = os.path.join(OUTPUT_DIR, save_prefix)

        # Raw f-stringによるLaTeXタイトルの表示 (最終修正版)
        ax.set_title(rf"{title} (Instantaneous Velocity Field: $\frac{{dz}}{{dt}}|_{{t=0}}$)", fontsize=18, pad=20)
        ax.set_xlabel("UMAP Dimension 1")
        ax.set_ylabel("UMAP Dimension 2")
        plt.tight_layout()
        plt.savefig(f"{final_save_path}.png", dpi=300); plt.savefig(f"{final_save_path}.pdf", bbox_inches='tight')
        print(f"Saved: {final_save_path}.png and {final_save_path}.pdf")
        plt.close(fig)

    def plot_velocity_components(self, snapshot, title_suffix="Velocity Components", 
                                 save_prefix="velocity_components"):
        emb_start, labels, components, _ = self._get_embeddings_and_components(snapshot)
        if emb_start is None:
            print("No valid data for velocity components visualization")
            return

        x, y = emb_start[:, 0], emb_start[:, 1]
        
        plot_data = {
            'Net Velocity (dz/dt)': np.linalg.norm(components['dz_dt'], axis=1),
            'Diffusion (GNN)': np.linalg.norm(components['diffusion'], axis=1),
            'Intrinsic Velocity': np.linalg.norm(components['velocity'], axis=1),
        }

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        min_size = 5
        max_size_range = 195
        
        for ax, (comp_name, comp_magnitude) in zip(axes, plot_data.items()):
            if comp_magnitude.max() > 0:
                norm_magnitude = comp_magnitude / comp_magnitude.max()
            else:
                norm_magnitude = comp_magnitude
            
            # ここでもノードサイズを動的に変更
            scaled_size = min_size + max_size_range * (norm_magnitude)
            
            scatter = ax.scatter(x, y, c=comp_magnitude, cmap='viridis',
                               s=scaled_size, alpha=0.5,
                               edgecolors='k', linewidths=0.5)
            
            fig.colorbar(scatter, ax=ax, label='Magnitude (Strength)')
            ax.set_title(comp_name)
            ax.set_xlabel("UMAP Dimension 1")
            ax.set_ylabel("UMAP Dimension 2")

        final_save_path = os.path.join(OUTPUT_DIR, save_prefix)

        plt.suptitle(f"Velocity Field Components Analysis - {title_suffix}", fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.95])
        plt.savefig(f"{final_save_path}.png", dpi=300); plt.savefig(f"{final_save_path}.pdf", bbox_inches='tight')
        print(f"Saved: {final_save_path}.png and {final_save_path}.pdf")
        plt.close(fig)

# ==========================================
# 8. Graph Structure Visualizer
# ==========================================
class GraphStructureVisualizer:
    def __init__(self, df_nodes, df_edges):
        self.df_nodes = df_nodes
        self.df_edges = df_edges
        self.category_names = ARXIV_CATEGORY_NAMES 
        
    def plot_category_subgraph(self, category_id, num_papers=100):
        cat_name = self.category_names.get(category_id, f'Cat {category_id}')
        print(f"Generating Subgraph Visualization for Category {cat_name}...")
        
        category_papers = self.df_nodes[self.df_nodes['category'] == category_id]
        if category_papers.empty:
            print(f"Category {category_id} not found.")
            return

        category_papers = category_papers.sort_values(by='year', ascending=False).head(num_papers)
        paper_ids = set(category_papers['node_id'].values)
        sub_edges = self.df_edges[
            (self.df_edges['source'].isin(paper_ids)) & 
            (self.df_edges['target'].isin(paper_ids))
        ]

        G = nx.DiGraph()
        G.add_edges_from(sub_edges[['source', 'target']].values)
        
        if not G.nodes:
            print(f"No citation links found within the top {num_papers} papers of Category {category_id}.")
            return
            
        node_years = category_papers.set_index('node_id')['year'].to_dict()
        nx.set_node_attributes(G, node_years, 'year')

        fig, ax = plt.subplots(figsize=(20, 20))
        pos = nx.spring_layout(G, k=0.15, iterations=50, seed=42)
        years_list = [G.nodes[n]['year'] for n in G.nodes if 'year' in G.nodes[n]]
        
        if not years_list:
             cmap = plt.cm.get_cmap('YlOrRd')
             norm = plt.Normalize(vmin=0, vmax=1)
        else:
             cmap = plt.cm.get_cmap('YlOrRd')
             norm = plt.Normalize(vmin=min(years_list), vmax=max(years_list))
             node_colors = years_list
        
        nx.draw_networkx_nodes(G, pos, node_size=150, node_color=years_list, 
                              cmap=cmap, edgecolors='gray', linewidths=0.5, ax=ax)
        nx.draw_networkx_edges(G, pos, arrowstyle='->', arrowsize=10, 
                              edge_color='gray', alpha=0.6, ax=ax)
        node_labels = {node: str(node) for node in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, font_color='black', ax=ax)
        
        if years_list:
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array(years_list)
            plt.colorbar(sm, ax=ax, orientation='vertical', label='Publication Year')

        ax.set_title(f"Citation Subgraph: {cat_name} (Top {num_papers} Papers)", 
                    fontsize=20)
        ax.axis('off')
        
        final_file_prefix = os.path.join(OUTPUT_DIR, f"subgraph_cat_{category_id}")

        plt.tight_layout()
        plt.savefig(f"{final_file_prefix}.png", dpi=300)
        plt.savefig(f"{final_file_prefix}.pdf", bbox_inches='tight')
        print(f"Saved: {final_file_prefix}.png and {final_file_prefix}.pdf")
        plt.close(fig)

# ==========================================
# 9. Results Comparison & Visualization
# ==========================================
def plot_sampling_comparison(all_results, save_prefix="sampling_comparison"):
    """複数のサンプリング戦略での比較可視化"""
    sampling_strategies = all_results['sampling'].unique()
    
    fig, axes = plt.subplots(2, 2, figsize=(18, 14))
    metrics = ['1_step_AUC', '1_step_AP', '2_step_AUC', '2_step_AP']
    titles = ['1-Step AUC', '1-Step AP', '2-Step AUC', '2-Step AP']
    
    for ax, metric, title in zip(axes.flat, metrics, titles):
        for strategy in sampling_strategies:
            strategy_data = all_results[all_results['sampling'] == strategy]
            for model_name in strategy_data['model'].unique():
                model_data = strategy_data[strategy_data['model'] == model_name]
                label = f"{model_name} ({strategy})"
                linestyle = '-' if strategy == 'random' else '--' if strategy == 'historical' else ':'
                ax.plot(range(len(model_data)), model_data[metric], 
                       marker='o', label=label, linewidth=2, linestyle=linestyle)
        
        ax.set_xlabel('Time Window Index', fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(loc='best', fontsize=8, ncol=2)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])
    
    plt.suptitle('Performance Comparison Across Sampling Strategies', fontsize=16, y=1.00)
    plt.tight_layout()
    plt.savefig(f"{save_prefix}.png", dpi=300, bbox_inches='tight')
    plt.savefig(f"{save_prefix}.pdf", bbox_inches='tight')
    print(f"Saved: {save_prefix}.png and {save_prefix}.pdf")
    plt.close(fig)

def print_summary_table(all_results):
    """結果の要約表を出力"""
    summary = all_results.groupby(['model', 'sampling']).agg({
        '1_step_AUC': ['mean', 'std'],
        '1_step_AP': ['mean', 'std'],
        '2_step_AUC': ['mean', 'std'],
        '2_step_AP': ['mean', 'std']
    }).round(4)
    
    print("\n" + "="*100)
    print("MODEL COMPARISON SUMMARY (BY SAMPLING STRATEGY)")
    print("="*100)
    print(summary)
    print("="*100 + "\n")
    
    return summary

# ==========================================
# 10. Main Execution Pipeline
# ==========================================
def main():
    print("\n" + "="*100)
    print("DYNAMIC GRAPH LEARNING WITH ADVANCED EVALUATION")
    print("="*100 + "\n")
    
    # ハイパーパラメータ
    HIDDEN_DIM = 32
    EPOCHS = 5
    LR = 0.01
    USE_GAT = True
    GAT_HEADS = 2   
    
    # 評価するサンプリング戦略
    SAMPLING_STRATEGIES = ['random', 'historical', 'inductive']
    
    # 1. データロード
    print("STEP 1: Loading Data...")
    factory = UniversalDataFactory('ogbn-arxiv')
    raw_data = factory.load_data()
    
    # 2. グラフ構築
    print("\nSTEP 2: Building Dynamic Graphs...")
    builder = DynamicGraphBuilder(raw_data)
    snapshots, num_cats, feat_dim = builder.build_snapshots()
    years = sorted(snapshots.keys())
    
    if len(years) < 2:
        print("Error: Not enough time steps for training.")
        return
    
    print(f"✓ Time snapshots: {years}")
    print(f"✓ Categories: {num_cats}, Feature dim: {feat_dim}")
    
    # 3. モデル初期化
    print("\nSTEP 3: Initializing Models (Baselines & Ablation)...")
    models = {
        'Static GCN': StaticGCN(num_cats, feat_dim, HIDDEN_DIM).to(device),
        'GRAND (Diffusion)': GRAND_ODE(num_cats, feat_dim, HIDDEN_DIM).to(device),
        'GREAD (React-Diff)': GREAD_ODE(num_cats, feat_dim, HIDDEN_DIM).to(device),
        
        # --- InnoVeloODE Ablation Studies ---
        'InnoVelo (Full)': InnoVeloODE(
            num_cats, feat_dim, HIDDEN_DIM, use_gat=USE_GAT, gat_heads=GAT_HEADS,
            ablation_mode='adaptive_decay'
        ).to(device),
        
        'InnoVelo (No Decay)': InnoVeloODE(
            num_cats, feat_dim, HIDDEN_DIM, use_gat=USE_GAT, gat_heads=GAT_HEADS,
            ablation_mode='no_decay'
        ).to(device),
        
        'InnoVelo (No Velocity)': InnoVeloODE(
            num_cats, feat_dim, HIDDEN_DIM, use_gat=USE_GAT, gat_heads=GAT_HEADS,
            ablation_mode='no_velocity'
        ).to(device)
    }
    
    for name in models.keys():
        print(f"  ✓ {name}")
    
    # 4. 学習
    print("\nSTEP 4: Training Models...")
    print("-" * 100)
    
    for name, model in models.items():
        
        torch.cuda.empty_cache()
        
        print(f"\nTraining {name}...")
        train_model(model, snapshots, years, device, 
                             epochs=EPOCHS, lr=LR, model_name=name)
    
    # 5. 複数のサンプリング戦略での評価
    print("\nSTEP 5: Evaluating Models with Multiple Sampling Strategies...")
    print("-" * 100)
    
    all_results = []
    for strategy in SAMPLING_STRATEGIES:
        print(f"\n### Evaluating with {strategy.upper()} sampling ###")
        for name, model in models.items():
            print(f"  Evaluating {name}...")
            results = evaluate_multi_step_enhanced(
                model, snapshots, years, device, 
                sampling_strategy=strategy, model_name=name
            )
            all_results.append(results)
    
    all_results_df = pd.concat(all_results, ignore_index=True)
    
    # 6. 結果の可視化と要約
    print("\nSTEP 6: Visualizing Results and Summarizing...")
    print("-" * 100)
    
    plot_sampling_comparison(all_results_df, "sampling_strategy_comparison")
    summary = print_summary_table(all_results_df)
    
    all_results_df.to_csv("evaluation_results_enhanced.csv", index=False)
    summary.to_csv("summary_results_enhanced.csv")
    print("✓ Saved: evaluation_results_enhanced.csv, summary_results_enhanced.csv")
    
    # 7. Neural ODE (GAT版) の詳細可視化
    print("\nSTEP 7: Detailed Visualization for InnoVelo (Full)...")
    print("-" * 100)
    
    ode_gat_model = models['InnoVelo (Full)']
    latest_snapshot = snapshots[years[-1]]
    
    visualizer = StreamlineVisualizer(ode_gat_model, device)
    visualizer.plot_streamline(latest_snapshot, 
                              title="InnoVelo (Full): Tech Trend Streamline",
                              save_prefix="innovelo_full_streamline")
    visualizer.plot_velocity_components(latest_snapshot,
                                       title_suffix="InnoVelo (Full)",
                                       save_prefix="innovelo_full_velocity_components")
    
    
    # 9. グラフ構造の可視化 (Top Category)
    print("\nSTEP 8: Graph Structure Visualization (Top Category)...")
    print("-" * 100)
    
    graph_viz = GraphStructureVisualizer(raw_data['df_nodes'], raw_data['df_edges'])

    ALL_CATEGORIES = range(40)

    for category_id in ALL_CATEGORIES:
        graph_viz.plot_category_subgraph(category_id=category_id, num_papers=100)
    
    print("\n" + "="*100)
    print("ALL VISUALIZATIONS COMPLETED!")
    
    # 10. 重要な発見の要約
    print("\n" + "="*100)
    print("ALL EXPERIMENTS COMPLETED SUCCESSFULLY!")
    print("="*100)

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
✓ Output directory created: visualizations_output

✓ OGB-Arxiv Categories Loaded: 40 unique labels.

DYNAMIC GRAPH LEARNING WITH ADVANCED EVALUATION

STEP 1: Loading Data...
Loading dataset: ogbn-arxiv...
✓ Loaded OGB dataset successfully
✓ Data loaded: 169343 papers, 40 categories.

STEP 2: Building Dynamic Graphs...
Building temporal snapshots...
✓ Built 6 snapshots.
✓ Time snapshots: [2015, 2016, 2017, 2018, 2019, 2020]
✓ Categories: 40, Feature dim: 128

STEP 3: Initializing Models (Baselines & Ablation)...
✓ Using GAT (2 heads) - Mode: adaptive_decay
✓ Using GAT (2 heads) - Mode: no_decay
✓ Using GAT (2 heads) - Mode: no_velocity
  ✓ Static GCN
  ✓ GRAND (Diffusion)
  ✓ GREAD (React-Diff)
  ✓ InnoVelo (Full)
  ✓ InnoVelo (No Decay)
  ✓ InnoVelo (No Velocity)

STEP 4: Training Models...
----------------------------------------------------------------------------------------------------

Training Static GCN...
Static GCN - Epoch 01/5 | Loss: 1.2685 | AUC: 0.7263
S

  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")


Saved: visualizations_output/innovelo_full_streamline.png and visualizations_output/innovelo_full_streamline.pdf
Sampling 3000 out of 169383 nodes for visualization...


  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")


Saved: visualizations_output/innovelo_full_velocity_components.png and visualizations_output/innovelo_full_velocity_components.pdf

STEP 8: Graph Structure Visualization (Top Category)...
----------------------------------------------------------------------------------------------------
Generating Subgraph Visualization for Category arxiv cs na...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_0.png and visualizations_output/subgraph_cat_0.pdf
Generating Subgraph Visualization for Category arxiv cs mm...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_1.png and visualizations_output/subgraph_cat_1.pdf
Generating Subgraph Visualization for Category arxiv cs lo...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_2.png and visualizations_output/subgraph_cat_2.pdf
Generating Subgraph Visualization for Category arxiv cs cy...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_3.png and visualizations_output/subgraph_cat_3.pdf
Generating Subgraph Visualization for Category arxiv cs cr...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_4.png and visualizations_output/subgraph_cat_4.pdf
Generating Subgraph Visualization for Category arxiv cs dc...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_5.png and visualizations_output/subgraph_cat_5.pdf
Generating Subgraph Visualization for Category arxiv cs hc...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_6.png and visualizations_output/subgraph_cat_6.pdf
Generating Subgraph Visualization for Category arxiv cs ce...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_7.png and visualizations_output/subgraph_cat_7.pdf
Generating Subgraph Visualization for Category arxiv cs ni...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_8.png and visualizations_output/subgraph_cat_8.pdf
Generating Subgraph Visualization for Category arxiv cs cc...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_9.png and visualizations_output/subgraph_cat_9.pdf
Generating Subgraph Visualization for Category arxiv cs ai...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_10.png and visualizations_output/subgraph_cat_10.pdf
Generating Subgraph Visualization for Category arxiv cs ma...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_11.png and visualizations_output/subgraph_cat_11.pdf
Generating Subgraph Visualization for Category arxiv cs gl...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_12.png and visualizations_output/subgraph_cat_12.pdf
Generating Subgraph Visualization for Category arxiv cs ne...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_13.png and visualizations_output/subgraph_cat_13.pdf
Generating Subgraph Visualization for Category arxiv cs sc...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_14.png and visualizations_output/subgraph_cat_14.pdf
Generating Subgraph Visualization for Category arxiv cs ar...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_15.png and visualizations_output/subgraph_cat_15.pdf
Generating Subgraph Visualization for Category arxiv cs cv...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_16.png and visualizations_output/subgraph_cat_16.pdf
Generating Subgraph Visualization for Category arxiv cs gr...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_17.png and visualizations_output/subgraph_cat_17.pdf
Generating Subgraph Visualization for Category arxiv cs et...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_18.png and visualizations_output/subgraph_cat_18.pdf
Generating Subgraph Visualization for Category arxiv cs sy...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_19.png and visualizations_output/subgraph_cat_19.pdf
Generating Subgraph Visualization for Category arxiv cs cg...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_20.png and visualizations_output/subgraph_cat_20.pdf
Generating Subgraph Visualization for Category arxiv cs oh...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_21.png and visualizations_output/subgraph_cat_21.pdf
Generating Subgraph Visualization for Category arxiv cs pl...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_22.png and visualizations_output/subgraph_cat_22.pdf
Generating Subgraph Visualization for Category arxiv cs se...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_23.png and visualizations_output/subgraph_cat_23.pdf
Generating Subgraph Visualization for Category arxiv cs lg...
No citation links found within the top 100 papers of Category 24.
Generating Subgraph Visualization for Category arxiv cs sd...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_25.png and visualizations_output/subgraph_cat_25.pdf
Generating Subgraph Visualization for Category arxiv cs si...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_26.png and visualizations_output/subgraph_cat_26.pdf
Generating Subgraph Visualization for Category arxiv cs ro...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_27.png and visualizations_output/subgraph_cat_27.pdf
Generating Subgraph Visualization for Category arxiv cs it...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_28.png and visualizations_output/subgraph_cat_28.pdf
Generating Subgraph Visualization for Category arxiv cs pf...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_29.png and visualizations_output/subgraph_cat_29.pdf
Generating Subgraph Visualization for Category arxiv cs cl...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_30.png and visualizations_output/subgraph_cat_30.pdf
Generating Subgraph Visualization for Category arxiv cs ir...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_31.png and visualizations_output/subgraph_cat_31.pdf
Generating Subgraph Visualization for Category arxiv cs ms...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_32.png and visualizations_output/subgraph_cat_32.pdf
Generating Subgraph Visualization for Category arxiv cs fl...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_33.png and visualizations_output/subgraph_cat_33.pdf
Generating Subgraph Visualization for Category arxiv cs ds...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_34.png and visualizations_output/subgraph_cat_34.pdf
Generating Subgraph Visualization for Category arxiv cs os...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_35.png and visualizations_output/subgraph_cat_35.pdf
Generating Subgraph Visualization for Category arxiv cs gt...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_36.png and visualizations_output/subgraph_cat_36.pdf
Generating Subgraph Visualization for Category arxiv cs db...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_37.png and visualizations_output/subgraph_cat_37.pdf
Generating Subgraph Visualization for Category arxiv cs dl...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_38.png and visualizations_output/subgraph_cat_38.pdf
Generating Subgraph Visualization for Category arxiv cs dm...


  cmap = plt.cm.get_cmap('YlOrRd')


Saved: visualizations_output/subgraph_cat_39.png and visualizations_output/subgraph_cat_39.pdf

ALL VISUALIZATIONS COMPLETED!

ALL EXPERIMENTS COMPLETED SUCCESSFULLY!
