In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import sys
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import umap
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from matplotlib.ticker import FuncFormatter
import numpy as np
import json

# Put the parent of the current working directory on sys.path (once), presumably so the
# `script_classification` package imported below resolves when the notebook runs from its own folder.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path: 
    sys.path.insert(0, project_root)

from script_classification.utilities.graph_stats import *
from script_classification.data_loader import BitcoinScriptsDataset
from script_classification.models import GraphEncoder
from script_classification.engine import *
from script_classification.losses import *
from script_classification.utilities.graph_ops import *
from script_classification.evaluation.metrics import *
from script_classification.evaluation.embeddings import *
from script_classification.evaluation.clustering import *


In [None]:
# Apply seaborn's default theme to the figures drawn in this notebook.
sns.set_theme()

In [None]:
# Seed used for the generator that drives the random dataset split.
SEED = 64
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Read the run configuration. The context manager closes the file handle as soon as
# parsing is done instead of leaving it open until garbage collection.
with open("config.json", encoding="utf-8") as config_file:
    config = json.load(config_file)

data_root = config["data_root"]
saves_root = config["saves_root"]
out_dim = config["out_dim"]
batch_size = config["batch_size"]
hidden_channels = config["hidden_channels"]  # note that this has to be divisible by HEADS
# Full path of the saved encoder checkpoint.
model_save_path = os.path.join(saves_root, config["model_save_filename"])

In [None]:
dataset = BitcoinScriptsDataset(root=data_root)
# Prefer the dataset's own edge-feature count. The first graph is only inspected when
# that attribute is missing: passing `dataset[0]...` as the getattr default would
# evaluate it unconditionally, loading a graph even when it is not needed.
EDGE_DIM = getattr(dataset, "num_edge_features", None)
if EDGE_DIM is None:
    EDGE_DIM = dataset[0].edge_attr.size(1)

In [None]:
# Map every graph's `graph_id` to its position in `dataset` (iterates the whole dataset once).
graph_id_to_idx_map = {data.graph_id: i for i, data in enumerate(dataset)}

In [None]:
# 70% / 15% / 15% train / validation / test sizes; the test split takes the remainder
# so the three sizes always add up to len(dataset).
n = len(dataset)
n_train = int(0.7 * n)
n_val   = int(0.15 * n)
n_test  = n - n_train - n_val

# Seeded generator so the random split is the same on every run.
gen = torch.Generator().manual_seed(SEED)
train_raw, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [n_train, n_val, n_test], generator=gen
)

In [None]:
# Loader over the held-out test split. shuffle=False keeps the batches in dataset order;
# NOTE(review): later cells pair the embeddings with `test_dataset` by position, which
# presumably relies on embed_roots preserving this order — confirm.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Build the encoder with dimensions taken from the dataset and config.json, then move it
# to the selected device. The weights are still untrained at this point.
encoder = GraphEncoder(
    in_channels=dataset.num_node_features,
    hidden_channels=hidden_channels, 
    out_channels=out_dim,
    edge_dim=EDGE_DIM
).to(device)

In [None]:
# Restore the trained encoder weights stored under the "model_state" key of the checkpoint.
ckpt = torch.load(model_save_path, map_location=device)
encoder.load_state_dict(ckpt["model_state"])

# Embed the test split.
# NOTE(review): presumably embed_roots switches the encoder to eval mode and disables
# gradients; if it does not, call encoder.eval() before this line — confirm.
embeddings_for_analysis = embed_roots(encoder, test_loader, device)

In [None]:
# Search for the number of clusters that best fits the embeddings (max_k=15).
# `best_silhouettes` is returned but not used by the cells shown in this notebook.
optimal_k, inertias, silhouettes, k_range, best_silhouettes = find_optimal_clusters_root(embeddings_for_analysis, max_k=15, random_state=33)
print(f"Optimal number of clusters: {optimal_k}")

In [None]:
# Plot the outcome of the K search from the inertias and silhouette scores collected above.
plot_find_optimal_clusters_root(optimal_k, inertias, silhouettes, k_range)

In [None]:
def visualize_embeddings_umap(graph_embeddings, cluster_labels, title="Graph Embedding Clusters (UMAP)", figsize=(5, 5)):
    """Project embeddings to 2-D with UMAP and show them coloured by cluster.

    Args:
        graph_embeddings: Embedding matrix passed to ``UMAP.fit_transform``.
        cluster_labels: One label per embedding; used as the scatter hue.
        title: Title of the plot.
        figsize: Size of the matplotlib figure.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Fixed random_state so the 2-D layout is the same on every run.
    projection = umap.UMAP(
        n_neighbors=15,
        min_dist=0.1,
        n_components=2,
        random_state=42,
        n_jobs=1,
    ).fit_transform(graph_embeddings)

    plt.figure(figsize=figsize)
    ax = sns.scatterplot(
        x=projection[:, 0],
        y=projection[:, 1],
        hue=cluster_labels,
        palette="tab20",
        legend="full",
        s=50,
        alpha=0.8,
    )

    # Place the legend outside the axes, to the right of the plot area.
    plt.legend(title="Cluster ID", loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)

    ax.set_title(title)
    plt.xlabel("UMAP Component 1")
    plt.ylabel("UMAP Component 2")
    plt.show()

In [None]:
# Fit the final K-Means model with the selected K and score the resulting assignment.
# NOTE(review): K-Means optimises Euclidean distances while the silhouette below is computed
# with the cosine metric — presumably the embeddings are normalised; confirm.
# NOTE(review): random_state (11) differs from the one passed to the K search (33), so this
# clustering is not necessarily the one that selected optimal_k.
k_means = KMeans(n_clusters=optimal_k, random_state=11, n_init=10)
final_clusters = k_means.fit_predict(embeddings_for_analysis)
final_score = silhouette_score(embeddings_for_analysis, final_clusters, metric="cosine")

print(f"Final Silhouette (cosine) with K={optimal_k}: {final_score:.4f}")

In [None]:
# 2-D UMAP view of the test-set embeddings, coloured by their K-Means cluster.
visualize_embeddings_umap(embeddings_for_analysis, final_clusters, title=f"Root Clusters (K={optimal_k})")

## Compare Clusters with External Labels

One of the publicly available sources for labeled Bitcoin addresses is 
[WalletExplorer](https://www.walletexplorer.com). 
To help explore our graph's clusters, we provide the WalletExplorer 
labels and the code used to process them.
For convenience, we offer several resources:

- Complete raw data retrieved from WalletExplorer.
- Pre-filtered version of the labels containing only addresses 
  that are present in the sampled communities. 
  This is the file used in this notebook.
- The scripts used to both retrieve and filter the data.

The notebook loads these pre-filtered labels. 
It then uses the wallet labels and the categories assigned by WalletExplorer 
(e.g., exchange, service, mining pool) 
to color-code the embeddings of the root nodes, shown next to the K-Means cluster assignments. 

In [None]:
# Collect, for every test graph, the WalletExplorer wallet label and category of its root node.
# Both lists follow the iteration order of `test_dataset`.
wallet_explorer_clusters_wallet_label = []
wallet_explorer_clusters_wallet_category = []
for graph in test_dataset:
    # Each graph has its own folder under <data_root>/raw/<graph_id>.
    graph_dir = os.path.join(data_root, "raw", graph.graph_id)
    
    # Only the first row of metadata.tsv is read for the root node index.
    root_node_idx = pd.read_csv(os.path.join(graph_dir, "metadata.tsv"), sep="\t")["RootNodeIdx"][0]
    graph_w_wallet_explorer_labels = pd.read_csv(os.path.join(graph_dir, "BitcoinScriptNode_Annotated.tsv"), sep="\t")
    
    # Positional lookup: assumes RootNodeIdx is a 0-based row position in the annotated TSV — TODO confirm.
    root_node = graph_w_wallet_explorer_labels.iloc[root_node_idx]

    wallet_explorer_clusters_wallet_label.append(root_node["WalletExplorer_WalletLabel"])    
    wallet_explorer_clusters_wallet_category.append(root_node["WalletExplorer_Category"])

In [None]:
def visualize_umap_multiple_label_groups(graph_embeddings, graph_cluster_labels, wallet_explorer_wallet_labels, wallet_explorer_category_labels):
    """Show one UMAP projection three times, coloured by three label sets.

    The embeddings are projected to 2-D once; the same coordinates are drawn in
    three panels that share the y axis, so the groupings can be compared directly.

    Args:
        graph_embeddings: Embedding matrix passed to ``UMAP.fit_transform``.
        graph_cluster_labels: Cluster label per embedding (left panel).
        wallet_explorer_wallet_labels: WalletExplorer wallet label per embedding
            (middle panel).
        wallet_explorer_category_labels: WalletExplorer category per embedding
            (right panel).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    _, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    # Fixed random_state so the layout is the same on every run.
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42, n_jobs=1)
    embeddings_2d = reducer.fit_transform(graph_embeddings)

    scatter_kwargs = {
        "x": embeddings_2d[:, 0],
        "y": embeddings_2d[:, 1],
        "palette": "tab10",
        "legend": "full",
        "s": 50,
        "alpha": 0.8
    }

    panels = [
        ('Graph Cluster Labels', graph_cluster_labels),
        ('WalletExplorer Wallet Labels', wallet_explorer_wallet_labels),
        ('WalletExplorer Category Labels', wallet_explorer_category_labels),
    ]
    for ax, (panel_title, hue_values) in zip(axes, panels):
        sns.scatterplot(**scatter_kwargs, hue=hue_values, ax=ax)
        ax.set_title(panel_title)
        # Every panel plots the first UMAP component on x, so each one is labelled
        # 'UMAP 1'; the second component is on the shared y axis.
        ax.set_xlabel('UMAP 1')
        # Legend goes below the axes so it does not cover the points.
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, frameon=False)

    # The y axis is shared, so only the left-most panel carries its label.
    axes[0].set_ylabel('UMAP 2')

    plt.subplots_adjust(wspace=0.05)
    plt.show()

In [None]:
# Compare the K-Means clusters with the WalletExplorer wallet labels and categories on one UMAP layout.
visualize_umap_multiple_label_groups(embeddings_for_analysis, final_clusters, wallet_explorer_clusters_wallet_label, wallet_explorer_clusters_wallet_category)

## Comparing Communities by Graph-Level Features

You may also characterize the sampled communities using their graph-level features. 
These include statistics that summarize the entire sampled subgraph, 
such as the average block height or BTC value calculated per hop from the community's root node.

However, it is crucial to understand a key distinction when using these for evaluation.

The model we trained here performs node-level representation learning. 
It is designed to generate a feature vector (an embedding) 
for the root node of each sampled community, based on its neighborhood. 
In contrast, the statistics described above are graph-level 
features that describe the entire community.

Therefore, clustering the root node embeddings and then evaluating 
those clusters based on graph-level features is a methodological mismatch,
and can lead to misleading conclusions.

For such a comparison to be valid, 
the model must be adapted for graph-level representation learning. 
This typically involves adding a readout or pooling function 
(e.g., global mean pooling), or another aggregation method, after the GNN layers 
to combine all node embeddings into a single embedding for the entire graph.

In [None]:
def format_block_height_axis(ax):
    """Show the x-axis tick labels of ``ax`` in thousands, e.g. 250000 -> '250k'."""

    def _as_thousands(value, _position):
        # Divide by 1000 and drop the decimals before appending the 'k' suffix.
        return f'{value / 1000:.0f}k'

    ax.xaxis.set_major_formatter(FuncFormatter(_as_thousands))

In [None]:
def prepare_stats_for_plotting(stats_list, graph_ids):
    """Flatten per-graph, per-hop statistics into a long-format DataFrame.

    Args:
        stats_list: One mapping per graph, from hop level to a mapping of that
            hop's statistics. The hop level must support ``hop_level + 1``.
        graph_ids: Graph identifiers, positionally aligned with ``stats_list``.

    Returns:
        A DataFrame with one row per (graph, hop, metric) found in the input.
        Each row has ``graph_id``, ``hop_transition`` (``"hop<k>->hop<k+1>"``),
        ``metric_type`` and the value in the column belonging to that metric
        (``avg_value``, ``avg_block_height``, ``avg_original_in_degree`` or
        ``avg_original_out_degree``); the other value columns are NaN.
    """
    # (key in the hop statistics, output value column, metric_type label)
    metric_specs = (
        ("Value_avg", "avg_value", "Avg Value"),
        ("BlockHeight_avg", "avg_block_height", "Avg BlockHeight"),
        ("OriginalInDegree_avg", "avg_original_in_degree", "Avg Original In-Degree"),
        ("OriginalOutDegree_avg", "avg_original_out_degree", "Avg Original Out-Degree"),
    )

    records = []
    for position, per_hop_stats in enumerate(stats_list):
        current_graph_id = graph_ids[position]
        for hop, hop_values in per_hop_stats.items():
            transition = f"hop{hop}->hop{hop+1}"
            # Emit one record for each metric that this hop actually reports.
            records.extend(
                {
                    "graph_id": current_graph_id,
                    "hop_transition": transition,
                    value_column: hop_values[source_key],
                    "metric_type": metric_label,
                }
                for source_key, value_column, metric_label in metric_specs
                if source_key in hop_values
            )

    return pd.DataFrame(records)

In [None]:
def compare_graph_stats(stats_list, graph_ids):
    """Plot per-hop average statistics of several graphs side by side.

    Draws four horizontal bar charts that share the hop-transition axis: average
    value (log-scaled x axis), average block height, average original in-degree
    and average original out-degree. Each graph keeps the same colour in every
    panel, and one shared legend at the top of the figure identifies the graphs.

    Args:
        stats_list: One per-hop statistics mapping per graph, in the format
            accepted by ``prepare_stats_for_plotting``.
        graph_ids: Graph identifiers, positionally aligned with ``stats_list``.

    Returns:
        None. The figure is displayed with ``plt.show()``. When no statistics
        could be extracted, a message is printed and nothing is drawn.
    """
    df = prepare_stats_for_plotting(stats_list, graph_ids)
    if df.empty:
        # An empty frame has no "hop_transition" column, so there is nothing to plot.
        print("No per-hop statistics available to plot.")
        return

    hop_order = sorted(df["hop_transition"].dropna().unique(), key=lambda x: int(x.split("->")[0][3:]))

    # Pin one colour per graph so the panels agree even when a graph is missing a metric.
    unique_graph_ids = list(dict.fromkeys(graph_ids))
    color_palette = dict(zip(unique_graph_ids, sns.color_palette("viridis", len(unique_graph_ids))))

    # (metric_type, value column, panel title, x-axis label)
    panels = [
        ("Avg Value", "avg_value", "Avg. Tx Value by Hop", "Average Value (Log10 BTC)"),
        ("Avg BlockHeight", "avg_block_height", "Avg. Block Height by Hop", "Average Block Height"),
        ("Avg Original In-Degree", "avg_original_in_degree", "Avg. Original In-Degree by Hop", "Average Original In-Degree"),
        ("Avg Original Out-Degree", "avg_original_out_degree", "Avg. Original Out-Degree by Hop", "Average Original Out-Degree"),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(18, 4), sharey=True)

    legend_handles, legend_labels = [], []
    for panel_idx, (ax, (metric_type, value_column, panel_title, x_label)) in enumerate(zip(axes, panels)):
        panel_df = df[df["metric_type"] == metric_type]

        if panel_df.empty:
            # The value column may not exist at all when no graph reports this metric.
            ax.text(0.5, 0.5, "no data", ha="center", va="center", transform=ax.transAxes)
        else:
            sns.barplot(data=panel_df,
                        y="hop_transition", x=value_column, hue="graph_id",
                        ax=ax, order=hop_order, hue_order=unique_graph_ids,
                        orient="h", palette=color_palette)
            if metric_type == "Avg Value":
                ax.set_xscale("log")
            elif metric_type == "Avg BlockHeight":
                format_block_height_axis(ax)

            # Keep one copy of the legend entries for the shared legend, then drop
            # the per-panel legend (it can be absent, hence the None check).
            panel_legend = ax.get_legend()
            if panel_legend is not None:
                if not legend_handles:
                    legend_handles, legend_labels = ax.get_legend_handles_labels()
                panel_legend.remove()

        # Set labels after plotting so they replace the column names seaborn writes.
        ax.set_title(panel_title)
        ax.set_xlabel(x_label)
        ax.set_ylabel("Hop Transition" if panel_idx == 0 else "")

    if legend_handles:
        fig.legend(legend_handles, legend_labels, loc="upper center",
                   ncol=len(legend_labels), frameon=False)

    # Leave the top strip of the figure free for the shared legend.
    plt.tight_layout(rect=[0, 0, 1, 0.92])
    plt.show()

In [None]:
def get_graphs_from_cluster(target_cluster_id, all_cluster_labels, dataset):
    """Return the dataset items whose cluster label equals ``target_cluster_id``.

    Args:
        target_cluster_id: Cluster label to select.
        all_cluster_labels: Cluster label per dataset item, positionally aligned
            with ``dataset``. Any array-like is accepted (NumPy array or list).
        dataset: Indexable collection of graphs.

    Returns:
        A list of the matching items in dataset order, or an empty list (after
        printing a notice) when no item belongs to the cluster.
    """
    # Convert first: comparing a plain Python list with a scalar gives a single
    # False instead of an element-wise mask, which would silently match nothing.
    labels = np.asarray(all_cluster_labels)
    indices = np.where(labels == target_cluster_id)[0]

    if len(indices) == 0:
        print(f"No graphs found for cluster ID {target_cluster_id}.")
        return []

    # Index with built-in ints; not every container accepts NumPy integer scalars.
    return [dataset[int(i)] for i in indices]

In [None]:
# Compare the per-hop statistics of the first two test graphs assigned to cluster 1.
# The cluster is queried once, and its size is checked before indexing so that a
# cluster with fewer than two graphs yields a message instead of an IndexError.
cluster_graphs = get_graphs_from_cluster(1, final_clusters, test_dataset)
if len(cluster_graphs) < 2:
    print(f"Cluster 1 holds {len(cluster_graphs)} graph(s); at least two are needed for a comparison.")
else:
    a_g_id = cluster_graphs[0].graph_id
    b_g_id = cluster_graphs[1].graph_id
    compare_graph_stats([dataset.per_graph_stats[a_g_id], dataset.per_graph_stats[b_g_id]], [a_g_id, b_g_id])