# In [None]:
"""
Elliptic++ っぽいトランザクショングラフに
GNN（GCN）を適用して可視化する完全版スクリプト。

1) pandas で TXT を読み込み
2) PyTorch Geometric の Data へ変換
3) GCN でノード分類を学習
4) t-SNE で埋め込み可視化
5) サブグラフ可視化
6) 混同行列を表示
"""

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from torch_geometric.nn import GCNConv

from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# ====== Input file paths (relative paths, resolved against the working directory) ======
TXS_FEATURES = "txs_features.txt"  # node feature table read by load_raw_data()
TXS_CLASSES = "txs_classes.txt"  # node class labels read by load_raw_data()
TXS_EDGES = "txs_edgelist.txt"  # edge list read by load_raw_data()


# ==========================
# ユーティリティ
# ==========================
def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus all CUDA devices when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def train_val_test_split(num_nodes: int,
                         train_ratio: float = 0.6,
                         val_ratio: float = 0.2):
    """Randomly split node indices into disjoint train / val / test masks.

    The permutation comes from ``torch.randperm``, so the split follows the
    global torch seed.

    Args:
        num_nodes: Total number of nodes.
        train_ratio: Fraction of nodes assigned to train (count is floored).
        val_ratio: Fraction of nodes assigned to val (count is floored).
            All remaining nodes go to test.

    Returns:
        ``(train_mask, val_mask, test_mask)``: bool tensors of shape
        ``(num_nodes,)`` that are pairwise disjoint and together cover all nodes.

    Raises:
        ValueError: if a ratio is negative or ``train_ratio + val_ratio >= 1``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio >= 1.0:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to < 1.0, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    perm = torch.randperm(num_nodes)
    n_train = int(num_nodes * train_ratio)
    n_val = int(num_nodes * val_ratio)

    # Consecutive slices of one permutation -> disjoint splits covering every node.
    bounds = (0, n_train, n_train + n_val, num_nodes)
    masks = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[perm[start:stop]] = True
        masks.append(mask)

    train_mask, val_mask, test_mask = masks
    return train_mask, val_mask, test_mask


# ==========================
# データ読み込み & Data 変換
# ==========================
def load_raw_data():
    """Read the features, classes and edge-list files and print a quick overview.

    All three files are read with ``pd.read_csv`` from the paths in
    ``TXS_FEATURES``, ``TXS_CLASSES`` and ``TXS_EDGES``.

    Returns:
        ``(df_features, df_classes, df_edges)`` as pandas DataFrames.
    """
    print("Loading txs_features...")
    feats = pd.read_csv(TXS_FEATURES)
    print("  shape:", feats.shape)
    print("  columns:", list(feats.columns[:10]), "...")

    print("\nLoading txs_classes...")
    labels = pd.read_csv(TXS_CLASSES)
    has_class_col = "class" in labels.columns
    print("  shape:", labels.shape)
    print("  columns:", list(labels.columns))
    if has_class_col:
        print("  class unique:", labels["class"].unique())

    print("\nLoading txs_edgelist...")
    edges = pd.read_csv(TXS_EDGES)
    print("  shape:", edges.shape)
    print("  columns:", list(edges.columns))

    # --- quick summary; per-column NaN counts are reused for both reports ---
    null_counts = feats.isnull().sum()

    print("\n=== Features overview ===")
    print(feats.head())
    print(feats.describe(include="all"))
    print("missing values (total):", int(null_counts.sum()))

    print("\n=== Features: missing per column (top 10) ===")
    print(null_counts.sort_values(ascending=False).head(10))

    print("\n=== Classes overview ===")
    if has_class_col:
        print(labels["class"].value_counts())
    else:
        print("WARNING: 'class' column not found in classes file.")

    print("\n=== Edges overview ===")
    print(edges.head())

    return feats, labels, edges


def _build_feature_matrix(df_features: pd.DataFrame, id_col) -> torch.Tensor:
    """Return every non-ID column as a float tensor, replacing NaN with 0."""
    feature_cols = [c for c in df_features.columns if c != id_col]
    # Every non-ID column must be convertible to float; otherwise this raises.
    x_np = df_features[feature_cols].to_numpy(dtype=float)
    nan_mask = np.isnan(x_np)
    n_nan = int(nan_mask.sum())
    if n_nan:
        # A NaN input makes every GCN output it reaches NaN, so neutralise it.
        print(f"[build_pyg_data] WARNING: replacing {n_nan} NaN feature values with 0.")
        x_np = np.where(nan_mask, 0.0, x_np)
    return torch.tensor(x_np, dtype=torch.float)


def _build_labels(df_classes: pd.DataFrame, node_ids, id_col_feat):
    """Align ``df_classes['class']`` to ``node_ids`` and category-code it.

    Returns ``(y, labeled_mask, classes)``. ``y`` is a long tensor in which
    unlabelled nodes get code -1. ``labeled_mask`` is a bool tensor.
    ``classes`` is the list of category values (``classes[k]`` is code ``k``).
    """
    if "class" not in df_classes.columns:
        raise ValueError("df_classes に 'class' 列がありません。")

    # Prefer the same ID column name as the features file.
    if id_col_feat in df_classes.columns:
        id_col_cls = id_col_feat
    else:
        # Fallback: assume the first column of the classes file is the ID.
        id_col_cls = df_classes.columns[0]
        print(f"[build_pyg_data] WARNING: using '{id_col_cls}' as id col for classes.")

    # Reorder labels to the feature row order; IDs missing from classes become NaN.
    cls_series = df_classes.set_index(id_col_cls)["class"].reindex(node_ids)
    labeled_mask = torch.tensor(cls_series.notna().to_numpy(), dtype=torch.bool)

    # Category coding works for string labels too; NaN -> code -1.
    cls_cat = cls_series.astype("category")
    y = torch.tensor(cls_cat.cat.codes.to_numpy(), dtype=torch.long)
    return y, labeled_mask, list(cls_cat.cat.categories)


def _build_edge_index(df_edges: pd.DataFrame, edge_cols, id2idx: dict) -> torch.Tensor:
    """Map (src, dst) node IDs to row indices, skipping edges with unknown IDs."""
    # Vectorised lookup. IDs are cast to int64 so they match the int keys of id2idx.
    src = df_edges[edge_cols[0]].astype("int64").map(id2idx)
    dst = df_edges[edge_cols[1]].astype("int64").map(id2idx)
    keep = (src.notna() & dst.notna()).to_numpy()

    missing_edges = int((~keep).sum())
    if missing_edges > 0:
        print(f"[build_pyg_data] skipped edges with unknown nodes: {missing_edges}")

    edge_np = np.vstack([src.to_numpy()[keep], dst.to_numpy()[keep]]).astype(np.int64)
    return torch.tensor(edge_np, dtype=torch.long)


def _print_edge_stats(df_edges: pd.DataFrame, edge_cols) -> None:
    """Print duplicate and reverse-direction edge counts of the raw edge list."""
    print("[build_pyg_data] duplicate edges:",
          int(df_edges.duplicated(subset=edge_cols).sum()))
    try:
        edge_pairs = set(map(tuple, df_edges[edge_cols].values))
    except Exception as err:  # best effort: statistics must not abort data building
        print("[build_pyg_data] reverse edge check failed:", repr(err))
        return
    # Counts unique pairs (u, v) whose reverse (v, u) also exists. Each
    # bidirectional pair is counted twice and a self-loop (u, u) once.
    reverse_count = sum((v, u) in edge_pairs for (u, v) in edge_pairs)
    print("[build_pyg_data] reverse edge count:", reverse_count)


def build_pyg_data(df_features: pd.DataFrame,
                   df_classes: pd.DataFrame,
                   df_edges: pd.DataFrame) -> "tuple[Data, list]":
    """Convert the raw DataFrames into a PyG ``Data`` object.

    - The first column of ``df_features`` is the node ID. Every other column
      becomes a node feature and must be convertible to float. NaN feature
      values are replaced with 0, and a warning is printed.
    - Labels come from ``df_classes['class']``. They are aligned to the
      feature row order and category-coded; unlabelled nodes get code -1.
    - The first two columns of ``df_edges`` are (src, dst) node IDs. Edges
      that reference an ID absent from ``df_features`` are skipped.
    - Random train/val/test masks from ``train_val_test_split`` are restricted
      to labelled nodes.

    Returns:
        ``(data, classes)``. ``data`` has ``x``, ``edge_index``, ``y``,
        ``train_mask``, ``val_mask`` and ``test_mask``. ``classes[k]`` is the
        original label value of code ``k``.

    Raises:
        ValueError: if ``df_classes`` has no ``'class'`` column.
    """
    id_col_feat = df_features.columns[0]
    print(f"\n[build_pyg_data] using '{id_col_feat}' as node id column.")

    # ---- feature matrix X ----
    x = _build_feature_matrix(df_features, id_col_feat)
    num_nodes = x.size(0)
    print("[build_pyg_data] x shape:", x.shape)

    # ---- node ID -> row index ----
    node_ids = df_features[id_col_feat].values
    id2idx = {int(nid): i for i, nid in enumerate(node_ids)}

    # ---- labels y ----
    y, labeled_mask, classes = _build_labels(df_classes, node_ids, id_col_feat)
    print("[build_pyg_data] classes:", classes)
    print("[build_pyg_data] labeled nodes:",
          int(labeled_mask.sum()), "/", num_nodes)

    # ---- edge_index ----
    edge_cols = df_edges.columns[:2]
    edge_index = _build_edge_index(df_edges, edge_cols, id2idx)
    print("[build_pyg_data] edge_index shape:", edge_index.shape)

    # ---- Data object + masks (only labelled nodes are trained/evaluated) ----
    data = Data(x=x, edge_index=edge_index, y=y)
    train_mask, val_mask, test_mask = train_val_test_split(num_nodes)
    data.train_mask = train_mask & labeled_mask
    data.val_mask = val_mask & labeled_mask
    data.test_mask = test_mask & labeled_mask

    print("[build_pyg_data] train/val/test labeled counts:",
          int(data.train_mask.sum()),
          int(data.val_mask.sum()),
          int(data.test_mask.sum()))

    _print_edge_stats(df_edges, edge_cols)

    return data, classes


# ==========================
# GCN モデル
# ==========================
class GCN(nn.Module):
    """Two GCNConv layers followed by a linear classification head.

    Layout: conv1 -> ReLU -> dropout -> conv2 -> ReLU -> Linear. The post-ReLU
    output of conv2 is the node embedding returned for visualisation.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        # Submodules are registered in this exact order; it fixes both the
        # parameter initialisation sequence and ``parameters()`` order.
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, return_embeddings=False):
        """Return logits, or ``(logits, embeddings)`` when ``return_embeddings`` is true."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        hidden = F.relu(self.conv2(hidden, edge_index))

        scores = self.lin(hidden)
        return (scores, hidden) if return_embeddings else scores


# ==========================
# 学習 & 評価
# ==========================
def evaluate(model: nn.Module, data: Data, split: str):
    """Compute mean cross-entropy loss and accuracy of ``model`` on one split.

    Args:
        model: Module called as ``model(x, edge_index)`` that returns logits.
            It must have at least one parameter, which is used to find the device.
        data: Graph with ``x``, ``edge_index``, ``y`` and
            ``train_mask`` / ``val_mask`` / ``test_mask`` attributes.
        split: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        ``(loss, accuracy)`` as Python floats, or ``(None, None)`` when the
        selected mask contains no nodes.

    Raises:
        ValueError: for an unknown ``split``.
    """
    model.eval()
    device = next(model.parameters()).device

    x, edge_index, y = (t.to(device) for t in (data.x, data.edge_index, data.y))

    mask_attr = {"train": "train_mask", "val": "val_mask", "test": "test_mask"}.get(split)
    if mask_attr is None:
        raise ValueError("split must be 'train', 'val' or 'test'")
    mask = getattr(data, mask_attr).to(device)

    with torch.no_grad():
        logits = model(x, edge_index)
    pred = logits.argmax(dim=-1)

    n_selected = int(mask.sum())
    if n_selected == 0:
        return None, None

    y_sel = y[mask]
    loss = F.cross_entropy(logits[mask], y_sel, reduction="mean")
    acc = (pred[mask] == y_sel).sum().item() / n_selected
    return loss.item(), acc


def train_gcn(data: Data, num_classes: int,
              hidden_dim: int = 64,
              lr: float = 1e-3,
              epochs: int = 50):
    """Train a ``GCN`` node classifier full-batch on ``data.train_mask``.

    Training uses Adam (``weight_decay=5e-4``) on CUDA when available, else CPU.
    Train/val loss and accuracy are printed at epoch 1 and every 10th epoch.
    The test split is evaluated once after the final epoch. A split whose mask
    is empty is reported as ``n/a`` instead of crashing the log line.

    Args:
        data: Graph with ``x``, ``edge_index``, ``y`` and the three masks.
        num_classes: Number of output logits.
        hidden_dim: Width of both GCN layers.
        lr: Adam learning rate.
        epochs: Number of full-batch optimisation steps.

    Returns:
        The trained model, left on the training device.

    Raises:
        ValueError: if ``data.train_mask`` selects no nodes. The mean
            cross-entropy over zero nodes is NaN, so training would be meaningless.
    """
    if int(data.train_mask.sum()) == 0:
        raise ValueError("data.train_mask selects no nodes; cannot train.")

    def _fmt(value):
        # evaluate() returns None for an empty split; "{None:.4f}" would raise TypeError.
        return "n/a" if value is None else f"{value:.4f}"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("[train_gcn] device:", device)

    model = GCN(
        in_dim=data.num_node_features,
        hidden_dim=hidden_dim,
        out_dim=num_classes
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

    data = data.to(device)

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()

        logits = model(data.x, data.edge_index)
        # train_mask already excludes unlabelled nodes (see build_pyg_data).
        mask = data.train_mask
        loss = F.cross_entropy(
            logits[mask],
            data.y[mask],
            reduction="mean"
        )
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0 or epoch == 1:
            train_loss, train_acc = evaluate(model, data, "train")
            val_loss, val_acc = evaluate(model, data, "val")
            print(f"[Epoch {epoch:03d}] "
                  f"train_loss={_fmt(train_loss)} acc={_fmt(train_acc)} | "
                  f"val_loss={_fmt(val_loss)} acc={_fmt(val_acc)}")

    # Final test evaluation.
    test_loss, test_acc = evaluate(model, data, "test")
    print(f"\n[Test] loss={_fmt(test_loss)} acc={_fmt(test_acc)}")

    return model


# ==========================
# 可視化 (t-SNE / サブグラフ / 混同行列)
# ==========================
def visualize_embeddings_tsne(model: nn.Module, data: Data,
                              title: str = "GNN embeddings (t-SNE)",
                              max_nodes: int = 2000):
    """Project the model's node embeddings to 2-D with t-SNE and show two plots.

    1. Points coloured by ground-truth label code. Nodes with a negative code
       are drawn as ``"unlabeled"``; ``build_pyg_data`` encodes missing labels as -1.
    2. Correct vs wrong predictions for labelled nodes only. Unlabelled nodes
       have no ground truth, so they are drawn as a separate grey group rather
       than being counted as "wrong".

    Args:
        model: Model whose ``forward(x, edge_index, return_embeddings=True)``
            returns ``(logits, embeddings)``.
        data: Graph with ``x``, ``edge_index`` and ``y``.
        title: Title of the first plot.
        max_nodes: If the graph is larger, a random subset of this many nodes
            is drawn with NumPy's global RNG, because t-SNE is slow.
    """
    device = next(model.parameters()).device
    model.eval()

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    y = data.y.cpu().numpy()

    with torch.no_grad():
        logits, z = model(x, edge_index, return_embeddings=True)
        y_pred = logits.argmax(dim=-1).cpu().numpy()

    z_np = z.cpu().numpy()
    n_total = z_np.shape[0]

    # Subsample large graphs.
    if n_total > max_nodes:
        idx = np.random.choice(n_total, size=max_nodes, replace=False)
        z_np, y, y_pred = z_np[idx], y[idx], y_pred[idx]
        print(f"[t-SNE] sampled {max_nodes}/{n_total} nodes for visualization.")

    n_points = z_np.shape[0]
    if n_points < 2:
        print("[t-SNE] need at least 2 nodes; skipping visualization.")
        return

    tsne = TSNE(
        n_components=2,
        random_state=0,
        # scikit-learn requires perplexity < n_samples; keep 30 whenever possible.
        perplexity=min(30, n_points - 1),
        init="pca",
        learning_rate="auto"
    )
    z_2d = tsne.fit_transform(z_np)

    labeled = y >= 0

    # --- coloured by class ---
    plt.figure(figsize=(8, 6))
    for c in np.unique(y):
        sel = (y == c)
        plt.scatter(
            z_2d[sel, 0],
            z_2d[sel, 1],
            s=8,
            alpha=0.6,
            label=f"class {c}" if c >= 0 else "unlabeled"
        )
    plt.title(title)
    plt.xlabel("t-SNE dim 1")
    plt.ylabel("t-SNE dim 2")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # --- correct / wrong (labelled nodes only) ---
    correct = labeled & (y == y_pred)
    wrong = labeled & (y != y_pred)
    plt.figure(figsize=(8, 6))
    if (~labeled).any():
        plt.scatter(
            z_2d[~labeled, 0],
            z_2d[~labeled, 1],
            s=8,
            alpha=0.3,
            color="lightgray",
            zorder=0,  # keep in the background
            label="unlabeled"
        )
    plt.scatter(
        z_2d[correct, 0],
        z_2d[correct, 1],
        s=8,
        alpha=0.6,
        label="correct"
    )
    plt.scatter(
        z_2d[wrong, 0],
        z_2d[wrong, 1],
        s=16,
        alpha=0.8,
        marker="x",
        label="wrong"
    )
    plt.title("Correct vs wrong predictions")
    plt.xlabel("t-SNE dim 1")
    plt.ylabel("t-SNE dim 2")
    plt.legend()
    plt.tight_layout()
    plt.show()


def visualize_subgraph_with_predictions(model: nn.Module,
                                        data: Data,
                                        num_nodes: int = 300,
                                        use_pred: bool = True,
                                        title: str = "Subgraph with GNN predictions"):
    """Draw a random node-induced subgraph with nodes coloured by class code.

    Every edge of ``data.edge_index`` whose two endpoints were both sampled is
    drawn as undirected, in either direction. The subgraph is built directly
    from ``edge_index`` instead of ``to_networkx(..., to_undirected=True)``,
    for two reasons:

    - ``to_networkx`` keeps only the upper triangle of the adjacency matrix,
      which would silently drop every edge u -> v with u > v of this directed
      graph.
    - It would also build the entire graph in NetworkX just to draw a few
      hundred nodes.

    Args:
        model: Model called as ``model(x, edge_index)`` that returns logits.
        data: Graph with ``x``, ``edge_index``, ``y`` and ``num_nodes``.
        num_nodes: Number of nodes sampled (NumPy global RNG) when the graph
            is larger than this.
        use_pred: Colour by predicted class if True, otherwise by ``data.y``.
        title: Plot title.
    """
    device = next(model.parameters()).device
    model.eval()

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    y_true = data.y.cpu().numpy()

    with torch.no_grad():
        logits = model(x, edge_index)
        y_pred = logits.argmax(dim=-1).cpu().numpy()

    labels_for_color = y_pred if use_pred else y_true

    total_nodes = int(data.num_nodes)
    if total_nodes > num_nodes:
        sampled_nodes = np.random.choice(total_nodes, size=num_nodes, replace=False)
    else:
        sampled_nodes = np.arange(total_nodes)
    # Sorted node order gives a deterministic node order for the layout.
    sampled_nodes = np.sort(sampled_nodes)

    # Keep edges with both endpoints sampled, regardless of direction.
    src, dst = data.edge_index.cpu().numpy()
    keep = np.isin(src, sampled_nodes) & np.isin(dst, sampled_nodes)

    H = nx.Graph()
    H.add_nodes_from(sampled_nodes.tolist())
    H.add_edges_from(zip(src[keep].tolist(), dst[keep].tolist()))
    pos = nx.spring_layout(H, seed=0)

    node_colors = [labels_for_color[n] for n in H.nodes()]

    plt.figure(figsize=(8, 8))
    nodes = nx.draw_networkx_nodes(
        H,
        pos,
        node_size=50,
        node_color=node_colors,
        cmap="tab10"
    )
    nx.draw_networkx_edges(
        H,
        pos,
        width=0.5,
        alpha=0.5
    )
    plt.colorbar(nodes, label="class (true/pred)")
    plt.axis("off")
    plt.title(title)
    plt.tight_layout()
    plt.show()


def visualize_confusion_matrix(model: nn.Module, data: Data,
                               classes=None,
                               title: str = "Confusion matrix",
                               mask=None):
    """Plot the confusion matrix of predicted vs. true class codes.

    Only nodes with a non-negative label are counted. ``build_pyg_data``
    encodes unlabelled nodes as -1, and counting them would add a spurious
    "-1" row/column whose size no longer matches ``classes``.

    Args:
        model: Model called as ``model(x, edge_index)`` that returns logits.
        data: Graph with ``x``, ``edge_index`` and ``y``.
        classes: Optional display names, where ``classes[k]`` names code ``k``.
            When given, the matrix is always ``len(classes)`` x ``len(classes)``,
            even if some class never occurs among the evaluated nodes.
        title: Plot title.
        mask: Optional boolean mask over nodes (tensor or array), for example
            ``data.test_mask``, that further restricts which nodes are counted.
            By default all labelled nodes are used.
    """
    device = next(model.parameters()).device
    model.eval()

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    y_true = data.y.cpu().numpy()

    with torch.no_grad():
        logits = model(x, edge_index)
        y_pred = logits.argmax(dim=-1).cpu().numpy()

    keep = y_true >= 0
    if mask is not None:
        keep &= torch.as_tensor(mask).cpu().numpy().astype(bool)

    if not keep.any():
        print("[confusion_matrix] no labeled nodes to evaluate; skipping plot.")
        return

    # Fix the label set so the matrix lines up with display_labels.
    labels = np.arange(len(classes)) if classes is not None else None
    cm = confusion_matrix(y_true[keep], y_pred[keep], labels=labels)
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=classes
    )
    disp.plot(values_format="d")
    plt.title(title)
    plt.tight_layout()
    plt.show()


# ==========================
# main
# ==========================
def main():
    """Run the full pipeline: load -> build PyG graph -> train GCN -> visualise."""
    set_seed(42)

    # Step 1: read the raw files and print an overview.
    raw_frames = load_raw_data()

    # Step 2: convert to a PyG Data object.
    data, classes = build_pyg_data(*raw_frames)

    # Step 3: train the GCN.
    train_opts = dict(hidden_dim=64, lr=1e-3, epochs=50)
    model = train_gcn(data, num_classes=len(classes), **train_opts)

    # Step 4: visualisations (t-SNE, sampled subgraph, confusion matrix).
    visualize_embeddings_tsne(model, data)
    subgraph_opts = dict(
        num_nodes=300,
        use_pred=True,
        title="Subgraph with predicted labels",
    )
    visualize_subgraph_with_predictions(model, data, **subgraph_opts)
    visualize_confusion_matrix(model, data, classes=classes)

    print("\nDone.")


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()