# GraphSAGE

## Imports

In [1]:
!nvidia-smi

Fri Nov 14 15:23:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   76C    P0             36W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
!pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html # Install appropriate verson see - https://www.dgl.ai/pages/start.html

Looking in links: https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html


In [3]:
import torch, traceback, gc
import torch.nn as nn
import torch.nn.functional as F
import dgl
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional

## Create Aggregators

In [4]:
class MeanAggregator(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
      super().__init__()
      self.input_dim = input_dim
      self.output_dim = output_dim
      self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, neighbor_features: torch.Tensor) -> torch.Tensor:
      # neighbor_features shape: [batch_size, num_neighbors, input_dim]
      # Compute mean along the neighbor dimension
      aggregated = torch.mean(neighbor_features, dim=1)  # [batch_size, input_dim]
      return self.linear(aggregated)

class MaxPoolAggregator(nn.Module): # No difference with min see paper - https://arxiv.org/pdf/1706.02216
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = None):
      super().__init__()
      self.input_dim = input_dim
      self.output_dim = output_dim
      if hidden_dim is None:
          hidden_dim = output_dim

      self.mlp = nn.Sequential(
          nn.Linear(input_dim, hidden_dim),
          nn.ReLU(),
          nn.Linear(hidden_dim, output_dim)
      )

    def forward(self, neighbor_features: torch.Tensor) -> torch.Tensor:
        # neighbor_features shape: [batch_size, num_neighbors, input_dim]
        batch_size, num_neighbors, input_dim = neighbor_features.shape

        # Reshape to apply MLP to each neighbor
        neighbor_features = neighbor_features.view(-1, input_dim)  # [batch_size * num_neighbors, input_dim]
        neighbor_features = self.mlp(neighbor_features)  # [batch_size * num_neighbors, output_dim]

        # Reshape back and apply max pooling
        neighbor_features = neighbor_features.view(batch_size, num_neighbors, -1)
        aggregated = torch.max(neighbor_features, dim=1)[0]  # [batch_size, output_dim]

        return aggregated

class LSTMAggregator(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = None):
      super().__init__()
      self.input_dim = input_dim
      self.output_dim = output_dim
      if hidden_dim is None:
          hidden_dim = output_dim

      self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
      self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, neighbor_features: torch.Tensor) -> torch.Tensor:
        # neighbor_features shape: [batch_size, num_neighbors, input_dim]
        # LSTM expects input in the same format
        _, (h_n, _) = self.lstm(neighbor_features)

        # Use the last hidden state
        h_n = h_n.squeeze(0)  # [batch_size, hidden_dim]
        return self.linear(h_n)


In [5]:
class GraphSAGELayer(nn.Module):
  def __init__(
    self,
    input_dim: int,
    output_dim: int,
    aggregator_type: str = 'mean',
    num_samples: int = 10,
    dropout: float = 0.5,
    normalize: bool = True):

    super().__init__()

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.num_samples = num_samples
    self.normalize = normalize
    self.dropout = nn.Dropout(dropout)

    # Initialize aggregator
    if aggregator_type == 'mean':
        self.aggregator = MeanAggregator(input_dim, output_dim)
    elif aggregator_type == 'maxpool':
        self.aggregator = MaxPoolAggregator(input_dim, output_dim)
    elif aggregator_type == 'lstm':
        self.aggregator = LSTMAggregator(input_dim, output_dim)
    else:
        raise ValueError(f"Unknown aggregator type: {aggregator_type}")

    # Self transformation
    self.self_linear = nn.Linear(input_dim, output_dim)

    # Final transformation (concatenation of self + neighbor)
    self.final_linear = nn.Linear(output_dim * 2, output_dim)

  def sample_neighbors(self, graph: dgl.DGLGraph, nodes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
      # Use DGL's sampling functionality
      sampled_graph = dgl.sampling.sample_neighbors(
          graph, nodes, self.num_samples, replace=True
      )

      # Get source and destination nodes
      src, dst = sampled_graph.edges()
      return src, dst

  def forward(self, graph: dgl.DGLGraph, features: torch.Tensor, nodes: torch.Tensor) -> torch.Tensor:
      # if nodes is None:
      #     nodes = torch.arange(graph.number_of_nodes(), device=features.device)

      batch_size = len(nodes)
      print(f"Batch Size: {batch_size}")

      # Sample neighbors for each target node
      src_nodes, dst_nodes = self.sample_neighbors(graph, nodes)

      # Group neighbors by destination node
      neighbor_features_list = []

      for i, node in enumerate(nodes):
          # Find neighbors of current node
          neighbor_mask = (dst_nodes == node)
          node_neighbors = src_nodes[neighbor_mask]

          if len(node_neighbors) == 0:
              # If no neighbors, use zero features
              neighbor_feats = torch.zeros(self.num_samples, self.input_dim, device=features.device)
          else:
              # Get neighbor features
              neighbor_feats = features[node_neighbors]

              # Pad or truncate to max neighbours
              # If neighbours < num_samples, pad with features from random existing neighbours
              # Else if neighbours > num_samples, we randomly pick X(num_samples) number of neighbour features
              if len(neighbor_feats) < self.num_samples:
                  # Pad with repetition of existing neighbors
                  indices = torch.randint(0, len(neighbor_feats), (self.num_samples - len(neighbor_feats),))
                  padding = neighbor_feats[indices]
                  neighbor_feats = torch.cat([neighbor_feats, padding], dim=0)
              elif len(neighbor_feats) > self.num_samples:
                  # Random sampling
                  indices = torch.randperm(len(neighbor_feats))[:self.num_samples]
                  neighbor_feats = neighbor_feats[indices]

          neighbor_features_list.append(neighbor_feats)

      # Stack neighbor features
      neighbor_features = torch.stack(neighbor_features_list, dim=0)  # [batch_size, num_samples, input_dim]

      # Apply dropout
      neighbor_features = self.dropout(neighbor_features)

      # Aggregate neighbor features
      aggregated_neighbors = self.aggregator(neighbor_features)  # [batch_size, output_dim]

      # Transform self features
      self_features = features[nodes]  # [batch_size, input_dim]
      self_features = self.dropout(self_features)
      transformed_self = self.self_linear(self_features)  # [batch_size, output_dim]

      # Concatenate self and neighbor features
      combined = torch.cat([transformed_self, aggregated_neighbors], dim=1)  # [batch_size, 2*output_dim]

      # Final transformation
      output = self.final_linear(combined)  # [batch_size, output_dim]

      # Apply activation
      output = F.relu(output)

      # L2 normalization
      if self.normalize:
          output = F.normalize(output, p=2, dim=1)

      return output

class GraphSAGE(nn.Module):
  def __init__(
      self,
      input_dim: int,
      hidden_dims: List[int],
      output_dim: int,
      aggregator_type: str = 'mean',
      # num_samples: List[int] = None,
      num_samples: int = 2,
      dropout: float = 0.5,
      normalize: bool = True):
    super().__init__()

    # if num_samples is None:
    #     num_samples = [10] * len(hidden_dims)

    # assert len(hidden_dims) == len(num_samples), "Length mismatch between hidden_dims and num_samples"

    self.layers = nn.ModuleList()

    # Concatenate for stacking multiple layers
    dims = [input_dim] + hidden_dims + [output_dim]
    print(dims)

    # Create GraphSAGE layers
    for i in range(len(dims) - 1):
      layer = GraphSAGELayer(
          input_dim=dims[i],
          output_dim=dims[i + 1],
          aggregator_type=aggregator_type,
          # num_samples=num_samples[min(i, len(num_samples) - 1)],
          num_samples=num_samples,
          dropout=dropout if i < len(dims) - 2 else 0,  # No dropout in last layer
          normalize=normalize if i < len(dims) - 2 else False  # No normalization in last layer
      )
      self.layers.append(layer)

  def forward(self, graph: dgl.DGLGraph, features: torch.Tensor, nodes: torch.Tensor) -> torch.Tensor:
    h = features
    for layer in self.layers:
        h_new = layer(graph, h, nodes)

        # For next layer, we need to update the full feature matrix
        if nodes is not None:
            h_full = h.clone()
            h_full[nodes] = h_new
            h = h_full
        else:
            h = h_new

    return h if nodes is None else h[nodes]

class GraphSAGENodeClassifier(nn.Module):
    def __init__(self, input_dim: int, hidden_dims: List[int], num_classes: int, **kwargs):
      super().__init__()
      self.graphsage = GraphSAGE(input_dim, hidden_dims, hidden_dims[-1], **kwargs)
      self.classifier = nn.Linear(hidden_dims[-1], num_classes)

    def forward(self, graph: dgl.DGLGraph, features: torch.Tensor, nodes: torch.Tensor):
      embeddings = self.graphsage(graph, features, nodes)
      return self.classifier(embeddings)

def train(graph, features, labels, nodes, graph_v, features_v, labels_v, nodes_v, epoch=5):
  input_dim = features.shape[1]
  num_classes = 2 # Illicit or Licit only

  # Create model
  model = GraphSAGENodeClassifier(
      input_dim=input_dim,
      hidden_dims=[input_dim, input_dim],
      num_classes=num_classes,
      aggregator_type='maxpool', # Define type of aggregator here - mean, maxpool, lstm
      num_samples=2 # Define the number of neighbours to sample
  )

  if torch.cuda.is_available():
    model = model.cuda()
    features = features.cuda()
    labels = labels.cuda()
    nodes = nodes.cuda()
    graph = graph.to("cuda")

    features_v = features_v.cuda()
    labels_v = labels_v.cuda()
    nodes_v = nodes_v.cuda()
    graph_v = graph_v.to("cuda")

  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  criterion = nn.CrossEntropyLoss()

  for epoch in range(epoch):
      # Training
      model.train()
      print("TRAINING")
      optimizer.zero_grad()

      logits = model(graph, features, nodes)
      loss = criterion(logits, labels)

      # Backward pass
      loss.backward()
      optimizer.step()

      # Validation
      model.eval()
      print("VALIDATING")
      logits = model(graph_v, features_v, nodes_v)
      val_loss = criterion(logits, labels_v)

      print(f"Epoch {epoch}, Training Loss: {loss.item():.4f}, Validation Loss: {val_loss.item():.4f}")
  print("Returning model...")
  return model


# Dataset Preparation

In [6]:
!pip install kagglehub[pandas-datasets]



In [7]:
import kagglehub
from kagglehub import KaggleDatasetAdapter

# Set the path to the file you'd like to load
# LOAD FEATURES
file_path = "elliptic_bitcoin_dataset/elliptic_txs_features.csv"
data = kagglehub.load_dataset(
  KaggleDatasetAdapter.PANDAS,
  "ellipticco/elliptic-data-set",
  file_path,
  pandas_kwargs={"header": None},
)
# LOAD CLASSES
file_path = "elliptic_bitcoin_dataset/elliptic_txs_classes.csv"
classes = kagglehub.load_dataset(
  KaggleDatasetAdapter.PANDAS,
  "ellipticco/elliptic-data-set",
  file_path,
)
# LOAD EDGELIST
file_path = "elliptic_bitcoin_dataset/elliptic_txs_edgelist.csv"
edgelist = kagglehub.load_dataset(
  KaggleDatasetAdapter.PANDAS,
  "ellipticco/elliptic-data-set",
  file_path,
)
print(data.shape)
print(classes.shape)
print(edgelist.shape)

  data = kagglehub.load_dataset(


Using Colab cache for faster access to the 'elliptic-data-set' dataset.


  classes = kagglehub.load_dataset(


Using Colab cache for faster access to the 'elliptic-data-set' dataset.


  edgelist = kagglehub.load_dataset(


Using Colab cache for faster access to the 'elliptic-data-set' dataset.
(203769, 167)
(203769, 2)
(234355, 2)


In [8]:
# Perform Split
splits = np.load('elliptic_splits.npz', allow_pickle=True)
print(splits)
# data = pd.read_csv('elliptic_txs_features.csv', header=None, low_memory=False)
# edgelist = pd.read_csv('elliptic_txs_edgelist.csv', header=None, low_memory=False)

dataset = []
for item in splits.files:
  print(item)
  print(f"No. of rows: {len(splits[item])}")
  split_data = data[data.index.isin(splits[item])]
  split_classes = classes[classes.index.isin(splits[item])]
  txnIds = split_data[split_data.columns[0]].to_numpy()
  split_edgelist = edgelist[edgelist.iloc[:, 0].isin(txnIds)]
  split_edgelist = split_edgelist[split_edgelist.iloc[:, 1].isin(txnIds)]
  print(split_data.shape)
  print(split_edgelist.shape)

  # Filter nodes that are in edgelist only
  # IMPORTANT: Nodes will be used later, so it has to tally with the nodes in the edgelist
  split_data = split_data[split_data.iloc[:, 0].isin(split_edgelist.iloc[:,0]) | split_data.iloc[:, 0].isin(split_edgelist.iloc[:,1])]
  labels_data = split_classes[split_classes.iloc[:, 0].isin(split_edgelist.iloc[:,0]) | split_data.iloc[:, 0].isin(split_edgelist.iloc[:,1])]
  labels_data = labels_data.iloc[:, -1] # last column

  # Separate nodes, features
  original_nodes = split_data.iloc[:, 0]   # first column (transaction IDs)
  features_data = split_data.iloc[:, 1:-1]   # all columns except last and first column (txn id)

  # Remap Y labels to be either 0 or 1
  labels_data = labels_data[~labels_data.str.contains('unknown')]
  labels_data = labels_data.map({'2': 0, '1': 1})

  # Create a mapping from original txn IDs to new sequential indices
  # Necessary to make it easier to process IDs in graph
  unique_original_nodes = original_nodes.unique()
  node_mapping = {node_id:i for i, node_id in enumerate(unique_original_nodes)}

  # Initialize features and labels
  num_unique_nodes = len(unique_original_nodes)
  features = torch.zeros(num_unique_nodes, features_data.shape[1], dtype=torch.float)
  labels = torch.zeros(num_unique_nodes, dtype=torch.long)

  # Align
  nodes = torch.tensor([node_mapping[node_id] for node_id in original_nodes], dtype=torch.long)
  features[nodes] = torch.tensor(features_data.values, dtype=torch.float)
  labels[nodes] = torch.tensor(labels_data.values, dtype=torch.long)

  print(f"No. of nodes: {num_unique_nodes}") # Should be equal to num of rows in dataset
  print(f"Original nodes shape: {original_nodes.shape}")
  print(f"Original features shape: {features_data.shape}")
  print(f"Original labels shape: {labels_data.shape}")
  print(f"Features shape: {features.shape}")
  print(f"Labels shape: {labels.shape}")

  # Update the edgelist as well to use the new sequential indices
  src_nodes = split_edgelist[split_edgelist.columns[0]].map(node_mapping).values
  dest_nodes = split_edgelist[split_edgelist.columns[1]].map(node_mapping).values
  dataset.append({
      "set": item,
      "features": features,
      "labels": labels,
      "nodes": nodes,
      "src_nodes": src_nodes,
      "dest_nodes": dest_nodes
  })

# print(dataset[0])

NpzFile 'elliptic_splits.npz' with keys: train_idx, val_idx, test_idx
train_idx
No. of rows: 27938
(27938, 167)
(12882, 2)
No. of nodes: 15671
Original nodes shape: (15671,)
Original features shape: (15671, 165)
Original labels shape: (15671,)
Features shape: torch.Size([15671, 165])
Labels shape: torch.Size([15671])
val_idx
No. of rows: 9313
(9313, 167)
(1415, 2)
No. of nodes: 2227
Original nodes shape: (2227,)
Original features shape: (2227, 165)
Original labels shape: (2227,)
Features shape: torch.Size([2227, 165])
Labels shape: torch.Size([2227])
test_idx
No. of rows: 9313
(9313, 167)
(1550, 2)
No. of nodes: 2355
Original nodes shape: (2355,)
Original features shape: (2355, 165)
Original labels shape: (2355,)
Features shape: torch.Size([2355, 165])
Labels shape: torch.Size([2355])


In [9]:
# TRAIN MODEL

# Train Dataset
trainDS = dataset[0]
features = trainDS["features"]
labels = trainDS["labels"]
nodes = trainDS["nodes"]
src = torch.tensor(trainDS["src_nodes"], dtype=torch.long)
dst = torch.tensor(trainDS["dest_nodes"], dtype=torch.long)
print(features)
print(labels)
print(nodes)
print(src)
print(dst)

print("Label distribution:", torch.bincount(labels))

graph = dgl.graph((src, dst))
graph = dgl.add_self_loop(graph)

print(f"Number of nodes: {graph.number_of_nodes()}, Number of edges: {graph.number_of_edges()}")

# Validation Dataset
validationDS = dataset[1]
features_v = trainDS["features"]
labels_v = trainDS["labels"]
nodes_v = trainDS["nodes"]
src_v = torch.tensor(trainDS["src_nodes"], dtype=torch.long)
dst_v = torch.tensor(trainDS["dest_nodes"], dtype=torch.long)
graph_v = dgl.graph((src_v, dst_v))
graph_v = dgl.add_self_loop(graph_v)

print("\nSTARTING TRAINING...")
epoch = 100
model = train(graph, features, labels, nodes, graph_v, features_v, labels_v, nodes_v, epoch)
print("\nEND OF TRAINING")

tensor([[ 1.0000e+00,  1.6305e-01,  1.9638e+00,  ..., -1.3116e-01,
          6.7780e-01, -1.2061e-01],
        [ 1.0000e+00, -5.0271e-03,  5.7894e-01,  ..., -1.3116e-01,
          3.3321e-01, -1.2061e-01],
        [ 1.0000e+00, -1.5136e-01, -1.8467e-01,  ..., -1.3116e-01,
         -9.7524e-02, -1.2061e-01],
        ...,
        [ 4.9000e+01, -1.7249e-01, -9.0143e-02,  ..., -1.3116e-01,
         -9.7524e-02, -1.2061e-01],
        [ 4.9000e+01, -1.8297e-02, -1.1549e-01,  ..., -1.3116e-01,
         -9.7524e-02, -1.2061e-01],
        [ 4.9000e+01, -1.7041e-01, -7.8164e-02,  ..., -1.3116e-01,
         -9.7524e-02, -1.2061e-01]])
tensor([0, 0, 0,  ..., 1, 0, 1])
tensor([    0,     1,     2,  ..., 15668, 15669, 15670])
tensor([    9,    11,    13,  ..., 15523, 15662, 15661])
tensor([   10,    12,    11,  ..., 15662, 15478, 15489])
Label distribution: tensor([14702,   969])
Number of nodes: 15671, Number of edges: 28553

STARTING TRAINING...
[165, 165, 165, 165]
TRAINING
Batch Size: 15671
Batc

In [10]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# TEST
testDS = dataset[2]
features = testDS["features"]
labels = testDS["labels"]
nodes = testDS["nodes"]
src = torch.tensor(testDS["src_nodes"], dtype=torch.long)
dst = torch.tensor(testDS["dest_nodes"], dtype=torch.long)
print(features)
print(labels)
print(nodes)
print(src)
print(dst)

graph = dgl.graph((src, dst))
graph = dgl.add_self_loop(graph)

print(f"Number of nodes: {graph.number_of_nodes()}, Number of edges: {graph.number_of_edges()}")

print("\nTESTING")

model.eval()
print("Label distribution:", torch.bincount(labels))
labels = labels.detach().numpy()

if torch.cuda.is_available():
  features = features.cuda()
  nodes = nodes.cuda()
  graph = graph.to("cuda")

outputs = model(graph, features, nodes).cpu().argmax(dim=1).detach().numpy()
print("Pred distribution:", torch.bincount(model(graph, features, nodes).argmax(dim=1)))

acc = accuracy_score(labels, outputs)
prec = precision_score(labels, outputs)
rec = recall_score(labels, outputs)
f1  = f1_score(labels, outputs)
print("Accuracy: ", acc)
print("Precision: ", prec)
print("Recall: ", rec)
print("F1 Score: ", f1)


tensor([[ 1.0000, -0.1729, -0.1847,  ..., -0.0932, -0.0688, -0.1206],
        [ 1.0000, -0.1729, -0.1847,  ..., -0.1312, -0.0975, -0.1206],
        [ 1.0000,  0.0923,  1.2390,  ..., -0.1312,  0.0748, -0.1206],
        ...,
        [49.0000, -0.0942, -0.1162,  ..., -0.1312, -0.0975, -0.1206],
        [49.0000,  0.7024, -0.1227,  ..., -0.1312, -0.0975, -0.1206],
        [49.0000,  0.7033, -0.1202,  ..., -0.1312, -0.0975, -0.1206]])
tensor([0, 0, 0,  ..., 0, 0, 0])
tensor([   0,    1,    2,  ..., 2352, 2353, 2354])
tensor([   0,    3,   11,  ..., 2350, 2339, 2347])
tensor([   1,    6,   12,  ..., 2342, 2342, 2344])
Number of nodes: 2355, Number of edges: 3905

TESTING
Label distribution: tensor([2212,  143])
Batch Size: 2355
Batch Size: 2355
Batch Size: 2355
Batch Size: 2355
Batch Size: 2355
Batch Size: 2355
Pred distribution: tensor([2209,  146], device='cuda:0')
Accuracy:  0.9673036093418259
Precision:  0.72
Recall:  0.7552447552447552
F1 Score:  0.7372013651877133


## Setup
LR: 0.001

GraphSage Layers: 2

Num Neighbours: 2

----
## Split with illicit/licit nodes
| Type    | train loss | validation loss | Accuracy | Precision | Recall | F1 Score
| -------- | ------- | ------- | ------- | ------- | ------- | ------- |
| GraphSAGE-Mean   |   0.1036  |  0.0906   |   0.972  |  0.829   |  0.678   |  0.746   |
| GraphSAGE-MaxPool    |  0.1043  |  0.0987   |  0.967   |   0.72  |   0.755  |  0.737   |
| GraphSAGE-LSTM    |  0.0843  |  0.0733   |   0.975  |  0.85   |  0.713   |  0.776   |
