This notebook provides tools for profiling the dataset and the augmentation method.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# !pip install -r requirements.txt

In [None]:
import math
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from torch_geometric.data import Data
import seaborn as sns
from pprint import pprint
import json

from script_classification.data_loader import BitcoinScriptsDataset
from script_classification.view_augmenter import ViewAugmenter

In [None]:
# Apply seaborn's default theme to the figures drawn in this notebook.
sns.set_theme()

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load the project configuration. The context manager closes the file handle
# as soon as parsing finishes (the previous bare open() left it to the garbage
# collector), and the encoding is pinned to UTF-8 so that parsing does not
# depend on the platform's locale.
with open("script_classification/config.json", encoding="utf-8") as config_file:
    config = json.load(config_file)

# Root directory handed to the dataset constructor.
data_root = config["data_root"]

In [None]:
dataset = BitcoinScriptsDataset(root=data_root)

# Lookup table from each sample's graph_id to its position in the dataset,
# so a specific graph can be selected by id further down.
graph_id_to_idx_map = {}
for i in range(len(dataset)):
    gid = dataset.get(i).graph_id
    try:
        graph_id_to_idx_map[gid] = i
    except TypeError:
        # Ids that cannot be used as dict keys are stored under their string form.
        graph_id_to_idx_map[str(gid)] = i

In [None]:
# Print the node and edge feature names, each under its own heading.
for heading, attribute in (
    ("Node features:", "node_feature_names"),
    ("\nEdge features:", "edge_feature_names"),
):
    print(heading)
    pprint(getattr(dataset, attribute))

In [None]:
# Position of the dataset sample inspected in the cells below.
sample_idx = 0
# Alternative: pick the sample by graph_id through graph_id_to_idx_map.
# sample_idx = graph_id_to_idx_map["202508171811176566"]

In [None]:
# Fetch the selected sample and print a short summary of it.
data = dataset.get(sample_idx)

# Fall back to the positional index when the sample carries no graph_id.
shown_graph_id = getattr(data, "graph_id", sample_idx)
print(f"graph_id:\t{shown_graph_id}")
print(f"Node count:\t{data.num_nodes}")
# edge_index is indexed as (2, num_edges) here: the second dimension is the edge count.
shown_edge_count = data.edge_index.size(1)
print(f"Edge count:\t{shown_edge_count}")

In [None]:
# Place the sample on the selected device before running the augmenter.
# NOTE(review): torch_geometric's Data.to() appears to move the object in place
# and return it, so `data` and `data_dev` may be the same object afterwards — confirm.
data_dev = data.to(device)

In [None]:
# Build the augmenter that produces the two views inspected below.
# NOTE(review): the *_col / *_cols arguments are presumably column indices into
# the feature matrices — confirm against ViewAugmenter.
augmenter = ViewAugmenter(
    block_height_col=1,
    block_height_scale_range=(0.99, 1.01),
    block_height_shift_range=(-3.0, 3.0),
    degree_cols=(2, 3),
    degree_jitter=0.05,
    value_col=0,
    value_jitter=0.05 # NOTE(review): this comment used to read "+/-25% jitter", which does not match 0.05 — confirm the intended magnitude
).to(device)

In [None]:
# Call eval() on the augmenter, then run it on the sample with gradient
# tracking disabled; the result is unpacked into the two views v1 and v2.
augmenter.eval()
with torch.no_grad():
    v1, v2 = augmenter(data_dev)

In [None]:
def to_np(t):
    """Return *t* as a NumPy array suitable for plotting.

    Tensors are detached from the autograd graph and copied to the CPU before
    conversion; any other input is passed through ``np.asarray``.
    """
    if not torch.is_tensor(t):
        return np.asarray(t)
    return t.detach().cpu().numpy()

In [None]:
# Node feature matrices of the original sample and both augmented views.
x_original, x_aug_v1, x_aug_v2 = (to_np(graph.x) for graph in (data, v1, v2))

# Edge attribute matrices; None stands in for any graph without edge_attr.
edge_attribute_original, edge_attribute_aug_v1, edge_attribute_aug_v2 = (
    None if getattr(graph, "edge_attr", None) is None else to_np(graph.edge_attr)
    for graph in (data, v1, v2)
)

In [None]:
def plot_histogram(ax, original_values, v1_values, v2_values, x_label, bins=10):
    """Overlay density histograms of one feature for the original and both views.

    The three series share one set of bin edges computed from their combined
    range. Letting each ``hist`` call choose its own edges makes the bars of
    different series cover different intervals, so the overlay cannot be
    compared bin by bin.

    Args:
        ax: Matplotlib axes to draw on.
        original_values: Feature values from the original graph.
        v1_values: Feature values from the first augmented view.
        v2_values: Feature values from the second augmented view.
        x_label: Text for the x-axis label.
        bins: Bin specification passed to ``numpy.histogram_bin_edges``
            (defaults to 10 bins).

    No legend is added here; callers build a shared figure-level legend from
    the handles returned by ``ax.get_legend_handles_labels()``.
    """
    series = (
        ("Original", original_values),
        ("Augmented V1", v1_values),
        ("Augmented V2", v2_values),
    )
    # Pool all values so the bin edges span every series.
    pooled = np.concatenate([np.ravel(values) for _, values in series])
    shared_edges = np.histogram_bin_edges(pooled, bins=bins)

    for label, values in series:
        ax.hist(values, bins=shared_edges, density=True, alpha=0.5, label=label)
    ax.set_xlabel(x_label)
    ax.set_ylabel("Density")

In [None]:
# One panel per edge feature (the first two columns), each overlaying the
# original sample and both augmented views.
fig, axes = plt.subplots(1, 2, figsize=(15, 4))
axes = axes.ravel()

for col in range(2):
    plot_histogram(
        axes[col],
        edge_attribute_original[:, col],
        edge_attribute_aug_v1[:, col],
        edge_attribute_aug_v2[:, col],
        dataset.edge_feature_names[col],
    )

# Both panels use the same three series, so a single figure-level legend
# built from the first panel's handles is enough.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.02),
    ncol=3,
    frameon=False,
)
fig.subplots_adjust(bottom=0.15)
plt.show()

In [None]:
# Grid of histograms, one panel per node feature column, at most three panels
# per row. Each panel overlays the original sample and both augmented views.
cols = x_original.shape[1]
cols_count = min(3, cols)
rows_count = math.ceil(cols / cols_count)
fig, axes = plt.subplots(rows_count, cols_count, figsize=(5*cols_count, 3*rows_count), squeeze=False)

node_feature_names = dataset.node_feature_names
# axes.flat walks the grid row by row, matching the column order of the features.
for j, ax in enumerate(axes.flat):
    if j >= cols:
        # Hide the unused panels at the end of the last row.
        ax.axis("off")
        continue
    # Fall back to a generic label when a column has no recorded name.
    name = node_feature_names[j] if j < len(node_feature_names) else f"node_feat_{j}"
    plot_histogram(ax, x_original[:, j], x_aug_v1[:, j], x_aug_v2[:, j], f"Node: {name}")

# All panels use the same three series, so one figure-level legend built from
# the first panel's handles is enough.
handles, labels = axes[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, -0.02), ncol=3, frameon=False)
# tight_layout() recomputes every subplot parameter, so it has to run before
# subplots_adjust(); in the opposite order the bottom margin is discarded.
fig.tight_layout()
fig.subplots_adjust(bottom=0.15)
plt.show()