## GNN-Based Outlier Detection Model

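To set up the environment for the GNN-based outlier detection model, follow the steps below to install PyTorch and the PyTorch Geometric (PyG) dependencies. The notebook additionally imports `pyspark`, `pandas`, `numpy`, `scikit-learn` and `pygod` (which provides the `GAE` detector); install those separately. The PyTorch/PyG steps are: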

1. Install **PyTorch** with its compatible versions of **torchvision** and **torchaudio** for CUDA 12.4 support:
    ```bash
    pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
    ```

2. Install PyG libraries (`pyg-lib`, `torch_scatter`, `torch_sparse`, `torch_cluster`, and `torch_spline_conv`):
    ```bash
    pip install pyg-lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu124.html
    ```

3. Install the **torch_geometric** package:
    ```bash
    pip install torch_geometric
    ```

In [1]:
import gc
import shutil
import numpy as np
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import subgraph
from sklearn.preprocessing import MinMaxScaler
from pyspark import SparkConf
from pyspark.sql import Window
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from functools import reduce

# Create (or reuse) the Spark session used by every cell of this notebook.
# Driver/executor memory and the max result size are raised to 64g, and the
# SQL session time zone is pinned to UTC.
config = [
    ("spark.driver.memory", "64g"),
    ("spark.executor.memory", "64g"),
    ("spark.driver.maxResultSize", "64g"),
    ("spark.sql.session.timeZone", "UTC"),
]
spark = (
    SparkSession.builder
    .appName("07_gnn_based_model")
    .config(conf=SparkConf().setAll(config))
    .getOrCreate()
)

In [3]:
# Dataset variant to process: either "HI-Small" or "LI-Small".
DATASET = "HI-Small"

# Inputs: node/edge tables of the temporal graph of sequential transactions,
# plus the preprocessed transactions table.
# NOTE(review): a later cell re-reads the edges from "05_temporal_graph", while
# these paths use "06_temporal_graph" — confirm which stage is intended.
_temporal_graph_dir = "../datasets/synthetic/06_temporal_graph"
nodes_location = f"{_temporal_graph_dir}/{DATASET}_nodes"
edges_location = f"{_temporal_graph_dir}/{DATASET}_edges"
trans_location = f"../datasets/synthetic/02_preprocessed/{DATASET}-transactions.parquet"

In [None]:
# Load the node and edge tables of the temporal graph of sequential transactions.
nodes, edges = (spark.read.parquet(path) for path in (nodes_location, edges_location))

In [6]:
# Load all transactions; only the id and the amount are needed to enrich the edges.
transactions = spark.read.parquet(trans_location)
amount_col = transactions.select(F.col("transaction_id"), F.col("amount"))

In [7]:
# Attach the transaction amount of both edge endpoints via left joins on
# transaction_id: src -> src_amount, dst -> dst_amount (null when no match).
for endpoint_col, amount_alias in (("src", "src_amount"), ("dst", "dst_amount")):
    edges = (
        edges.join(amount_col, amount_col.transaction_id == edges[endpoint_col], how="left")
        .drop("transaction_id")
        .withColumnRenamed("amount", amount_alias)
    )
del endpoint_col, amount_alias

# The date columns are dropped before the attribute computation.
edges = edges.drop("src_date", "dst_date")

In [8]:
# Split the edges into one DataFrame per edge_type ("flow", "fan_in", "fan_out"),
# dropping the now-constant edge_type column from each.
flow_edges, fan_in_edges, fan_out_edges = [
    edges.filter(F.col("edge_type") == edge_kind).drop("edge_type")
    for edge_kind in ("flow", "fan_in", "fan_out")
]

del edges

In [9]:
# Calculate node attributes from edges.
# Edge endpoints (src / dst) are transaction ids (they were joined against
# transaction_id above). For every edge family two attributes are derived per
# node: att_1..att_6 with the node as SOURCE, att_7..att_12 with it as TARGET.


def _ratio_attributes(df, window, own_amount, other_amount, ratio_att, delta_att):
    """Add `ratio_att` and `delta_att` to `df`, both rounded to 6 decimals.

    `ratio_att` is own_amount / (sum of other_amount over `window`), capped at 1.
    `delta_att` is the median of `delta` over `window`.
    """
    with_stats = (
        df.withColumn('sum_amount', F.sum(other_amount).over(window))
          .withColumn('median_delta', F.median('delta').over(window))
    )
    capped_ratio = F.least(F.col(own_amount) / F.col('sum_amount'), F.lit(1))
    return (
        with_stats.withColumn(ratio_att, F.round(capped_ratio, 6))
                  .withColumn(delta_att, F.round(F.col('median_delta'), 6))
                  .drop('sum_amount', 'median_delta')
    )


def _deviation_attributes(df, window, own_amount, other_amount, dev_att, delta_att):
    """Add `dev_att` and `delta_att` to `df`, both rounded to 6 decimals.

    `dev_att` is |own - m| / max(own, m), with m the median of other_amount over `window`.
    `delta_att` is the median of `delta` over `window`.
    """
    with_stats = (
        df.withColumn('median_amount', F.median(other_amount).over(window))
          .withColumn('median_delta', F.median('delta').over(window))
    )
    own = F.col(own_amount)
    median = F.col('median_amount')
    relative_gap = F.abs(own - median) / F.greatest(own, median)
    return (
        with_stats.withColumn(dev_att, F.round(relative_gap, 6))
                  .withColumn(delta_att, F.round(F.col('median_delta'), 6))
                  .drop('median_amount', 'median_delta')
    )


# ============== SOURCE ==============
window_spec = Window.partitionBy('src')

flow_edges_src = _ratio_attributes(
    flow_edges, window_spec, 'src_amount', 'dst_amount', 'att_1', 'att_2')
fan_in_edges_src = _deviation_attributes(
    fan_in_edges, window_spec, 'src_amount', 'dst_amount', 'att_3', 'att_4')
fan_out_edges_src = _deviation_attributes(
    fan_out_edges, window_spec, 'src_amount', 'dst_amount', 'att_5', 'att_6')

# ============== TARGET ==============
window_spec = Window.partitionBy('dst')

flow_edges_dst = _ratio_attributes(
    flow_edges, window_spec, 'dst_amount', 'src_amount', 'att_7', 'att_8')
fan_in_edges_dst = _deviation_attributes(
    fan_in_edges, window_spec, 'dst_amount', 'src_amount', 'att_9', 'att_10')
fan_out_edges_dst = _deviation_attributes(
    fan_out_edges, window_spec, 'dst_amount', 'src_amount', 'att_11', 'att_12')

In [10]:
# Reduce each enriched edge table to one row per node id with only that node's
# attributes. The att_* values presumably do not vary within a key (they come
# from window aggregates plus the node's own amount), so which duplicate is kept
# should not matter — this assumes transaction_id is unique in the transactions table.


def _one_row_per_node(df, key, *attribute_cols):
    """Project `df` onto `key` + `attribute_cols` and keep a single row per `key`."""
    return df.select(key, *attribute_cols).dropDuplicates([key])


# ============== SOURCE ==============
flow_edges_attr_src = _one_row_per_node(flow_edges_src, 'src', 'att_1', 'att_2')
fan_in_edges_attr_src = _one_row_per_node(fan_in_edges_src, 'src', 'att_3', 'att_4')
fan_out_edges_attr_src = _one_row_per_node(fan_out_edges_src, 'src', 'att_5', 'att_6')

# ============== TARGET ==============
flow_edges_attr_dst = _one_row_per_node(flow_edges_dst, 'dst', 'att_7', 'att_8')
fan_in_edges_attr_dst = _one_row_per_node(fan_in_edges_dst, 'dst', 'att_9', 'att_10')
fan_out_edges_attr_dst = _one_row_per_node(fan_out_edges_dst, 'dst', 'att_11', 'att_12')

In [11]:
# Single-column DataFrame with every node id; base for the attribute joins.
node_ids_col = nodes.select("id")

In [12]:
def _left_join_attributes(base, attribute_frames, key):
    """Left-join every frame onto `base`, in order, matching base.id to frame[key].

    After each join the `key` column is dropped and all nulls are filled with 0,
    so nodes without a given edge type get 0 for that type's attributes.
    """
    merged = base
    for frame in attribute_frames:
        merged = (
            merged.join(frame, merged.id == frame[key], how="left")
            .drop(key)
            .fillna(0)
        )
    return merged


# Attributes 1-6: the node seen as edge source.
edges_list_src = [flow_edges_attr_src, fan_in_edges_attr_src, fan_out_edges_attr_src]
src_attr = _left_join_attributes(node_ids_col, edges_list_src, "src")

# Attributes 7-12: the node seen as edge target.
edges_list_dst = [flow_edges_attr_dst, fan_in_edges_attr_dst, fan_out_edges_attr_dst]
dst_attr = _left_join_attributes(node_ids_col, edges_list_dst, "dst")

In [13]:
# Output directories for the per-source and per-target attribute tables;
# any previous output is removed first (missing directories are ignored).
src_attr_location = f"../datasets/synthetic/06_gnn_model/{DATASET}_nodes_src"
dst_attr_location = f"../datasets/synthetic/06_gnn_model/{DATASET}_nodes_dst"
for stale_dir in (src_attr_location, dst_attr_location):
    shutil.rmtree(stale_dir, ignore_errors=True)
del stale_dir

In [None]:
# Persist attributes 1-6 (node as source), repartitioned into 100 partitions.
partitions = 100

(
    src_attr
    .repartition(partitions)
    .write
    .mode("overwrite")
    .parquet(src_attr_location)
)

In [None]:
# Force a Python garbage-collection pass after the write above.
gc.collect()

In [None]:
# Persist attributes 7-12 (node as target), reusing the partition count set above.
(
    dst_attr
    .repartition(partitions)
    .write
    .mode("overwrite")
    .parquet(dst_attr_location)
)

In [None]:
# Force a Python garbage-collection pass after the write above.
gc.collect()

In [19]:
# Drop the Python references to the intermediate attribute DataFrames; the
# merged results are now on disk and get re-read in the next cell.
del flow_edges_attr_src, fan_in_edges_attr_src, fan_out_edges_attr_src
del flow_edges_attr_dst, fan_in_edges_attr_dst, fan_out_edges_attr_dst
del src_attr, dst_attr

In [20]:
# Reload the persisted source/target attribute tables.
src_attr, dst_attr = (
    spark.read.parquet(location) for location in (src_attr_location, dst_attr_location)
)

In [22]:
# Combine the 12 attributes per node: inner join on the shared `id` column,
# which keeps a single id column followed by att_1..att_6, then att_7..att_12.
nodes_attr = src_attr.join(dst_attr, on="id", how="inner")

In [23]:
# Output directory for the final node attribute table; clear previous output.
nodes_location_save = f"../datasets/synthetic/06_gnn_model/{DATASET}_nodes_gnn"
shutil.rmtree(path=nodes_location_save, ignore_errors=True)

In [None]:
# Persist the node attribute table (id + 12 attributes) in 200 partitions.
partitions = 200
nodes_attr.repartition(partitions).write.mode("overwrite").parquet(nodes_location_save)

In [25]:
# Release the references to the node-attribute DataFrames.
del nodes_attr
del src_attr
del dst_attr

In [None]:
# Force a Python garbage-collection pass after dropping the DataFrames above.
gc.collect()

In [27]:
# Re-read the temporal-graph edges to build the GNN edge table.
# NOTE(review): the first read of the edges uses "06_temporal_graph"; this one
# uses "05_temporal_graph". Confirm which stage is intended.
edges_location = "/".join(["../datasets/synthetic/05_temporal_graph", f"{DATASET}_edges"])
edges = spark.read.parquet(edges_location)

In [29]:
# Rename the endpoints to source/target and keep only the columns used below.
edges = edges.select(
    *(F.col(old).alias(new) for old, new in (("src", "source"), ("dst", "target"))),
    "weight",
    "delta",
    "edge_type",
    "src_date",
    "dst_date",
)

In [30]:
# One-hot encode edge_type into three 0/1 indicator columns.
for indicator_name, edge_kind in (
    ("is_flow", "flow"),
    ("is_fan_in", "fan_in"),
    ("is_fan_out", "fan_out"),
):
    edges = edges.withColumn(
        indicator_name, F.when(F.col("edge_type") == edge_kind, 1).otherwise(0)
    )
del indicator_name, edge_kind

In [None]:
# Persist the GNN edge table, one directory per (src_date, dst_date) pair;
# rows are first repartitioned by the same columns.
edges_location_save = f"../datasets/synthetic/06_gnn_model/{DATASET}_edges_gnn"
partition_by = ["src_date", "dst_date"]
(
    edges.repartition(*partition_by)
    .write
    .partitionBy(*partition_by)
    .mode("overwrite")
    .parquet(edges_location_save)
)
print("edges_written")

In [None]:
# Release the Spark edges DataFrame and force a Python garbage-collection pass.
del edges; gc.collect()

In [None]:
# Read the Isolation Forest (IF) scores, put the cutoff at the
# `normal_percentage` quantile and split transaction ids into a "normal" set
# (score below the cutoff) and an "anomalous" set (score at or above it).
# NOTE(review): higher scores are treated as more anomalous — confirm this
# matches how the scores in the IF output were produced.

if_scores_location = f"../datasets/synthetic/04_if_output/{DATASET}_if_scores.csv"
trans_location = f"../datasets/synthetic/02_preprocessed/{DATASET}-transactions.parquet"

scores = pd.read_csv(if_scores_location)
# Ids are handled as strings from here on (and looked up among the node ids later).
scores['transaction_id'] = scores['transaction_id'].astype(str)

normal_percentage = 70
threshold = scores['scores'].quantile(normal_percentage / 100)
normal_ids_set = set(scores.loc[scores['scores'] < threshold, 'transaction_id'])
anomalous_ids_set = set(scores.loc[scores['scores'] >= threshold, 'transaction_id'])

# Known laundering transactions must not be used as normal (training) nodes:
# move any that fell below the cutoff into the anomalous set.
transactions = pd.read_parquet(trans_location)
# Cast to str so these ids are comparable with the str-typed score ids; with an
# integer-typed column the intersection below would otherwise be silently empty.
real_laundering_ids = list(
    transactions.loc[transactions['is_laundering'] == 1, 'transaction_id'].astype(str)
)
real_laundering_ids_set = set(real_laundering_ids)
ids_to_remove = normal_ids_set & real_laundering_ids_set

normal_ids_set -= ids_to_remove
anomalous_ids_set |= ids_to_remove

# Sorted, so the id order (and thus the training-node order) does not depend on
# per-process string hash randomization.
normal_ids = sorted(normal_ids_set)
anomalous_ids = sorted(anomalous_ids_set)

In [None]:
# Read back the node attribute table and the edge table written earlier in
# this notebook. The directory is "synthetic" (it was misspelled "synthethic"
# here, which pointed at a location nothing was written to).
nodes_location_save = f"../datasets/synthetic/06_gnn_model/{DATASET}_nodes_gnn"
edges_location_save = f"../datasets/synthetic/06_gnn_model/{DATASET}_edges_gnn"

nodes = pd.read_parquet(nodes_location_save)
edges = pd.read_parquet(edges_location_save)

In [6]:
# Only the endpoints and the weight are used to build the graph.
edges = edges.loc[:, ['source', 'target', 'weight']]

In [None]:
# Force a Python garbage-collection pass after loading the pandas tables.
gc.collect()

In [None]:
# Build the full graph as a single PyG `Data` object:
#   x          -> min-max scaled node attributes (all `nodes` columns except `id`)
#   edge_index -> 2 x num_edges tensor of positional node indices
#   edge_attr  -> num_edges x 1 tensor with the edge weight

# Map each node id to its row position in `nodes`; edge endpoints and the
# training subset are expressed in these positions.
global_node_mapping = {node_id: idx for idx, node_id in enumerate(nodes['id'])}

## =================== ALL GRAPH NODE FEATURES =====================
scaler = MinMaxScaler()
node_features = nodes.drop(columns=['id']).to_numpy()
all_node_features = scaler.fit_transform(node_features)
all_node_features = torch.tensor(all_node_features, dtype=torch.float)

## =================== EDGE WEIGHT =====================
weight = edges[['weight']].to_numpy()
edge_weight = torch.tensor(weight, dtype=torch.float)

## =================== EDGE INDEX ==================
edges_ind = edges[['source', 'target']]
source_idx = edges_ind['source'].map(global_node_mapping)
target_idx = edges_ind['target'].map(global_node_mapping)
# An endpoint absent from nodes['id'] maps to NaN, which is not a valid node
# index once cast to torch.long; fail loudly instead of building a broken graph.
n_unmapped = int(source_idx.isna().sum() + target_idx.isna().sum())
if n_unmapped:
    raise ValueError(
        f"{n_unmapped} edge endpoint(s) are not present in nodes['id']; "
        "cannot build edge_index"
    )
edge_index_np = np.array([source_idx.to_numpy(), target_idx.to_numpy()])
edge_index = torch.tensor(edge_index_np, dtype=torch.long)

## =================== GRAPH DATA OBJECT ==================
data = Data(x=all_node_features, edge_index=edge_index, edge_attr=edge_weight)

In [9]:
# Training graph: the subgraph induced by the "normal" nodes only.
# NOTE(review): every id in normal_ids must be a node of the graph, otherwise
# the lookup raises KeyError — confirm the IF scores cover only graph nodes.
train_node_indices = [global_node_mapping[node_id] for node_id in normal_ids]

subset = torch.tensor(train_node_indices, dtype=torch.long)
# Pass num_nodes explicitly. Without it, PyG infers the node count from
# edge_index.max() + 1, which is too small (and breaks the node-mask build)
# whenever the highest-indexed nodes have no edges.
edge_index_sub, edge_attr_sub = subgraph(
    subset,
    data.edge_index,
    edge_attr=data.edge_attr,
    relabel_nodes=True,
    num_nodes=data.num_nodes,
)

# Features taken in `subset` order, matching the relabelled edge_index.
x_sub = data.x[subset]

train_data = Data(x=x_sub, edge_index=edge_index_sub, edge_attr=edge_attr_sub)

In [None]:
### Instantiate and train the GNN-based outlier detection (OD) model

In [10]:
from pygod.detector import GAE

# PyGOD Graph Autoencoder (GAE) detector: 2 layers of 16 hidden units,
# mini-batches of 4096 with num_neigh=[5, 5] neighbour sampling,
# trained for 100 epochs on GPU 0.
model = GAE(
    epoch=100,
    gpu=0,
    hid_dim=16,
    num_layers=2,
    batch_size=4096,
    num_neigh=[5, 5],
)

In [None]:
# Train the detector on the normal-only subgraph built above (train_data).
model.fit(train_data)

In [None]:
# Infer on the entire graph (all nodes, not only the training subgraph).
# With return_score=True the call returns a (predictions, scores) pair.
model_predictions, model_scores = model.predict(data, return_score=True)

In [13]:
from pathlib import Path

# Save one outlier score per node, aligned with nodes['id'] (the same row order
# used to build `data`).
if torch.is_tensor(model_scores):
    # .numpy() fails on CUDA tensors and on tensors that require grad;
    # detach()/cpu() are no-ops for a plain CPU tensor.
    model_scores = model_scores.detach().cpu().numpy()

results = pd.DataFrame({
    'id': nodes['id'],
    'score': model_scores
})

results_location = f"../results/synthetic/{DATASET}_GAE_100_epochs.csv"
# pandas does not create missing directories; make sure the target exists.
Path(results_location).parent.mkdir(parents=True, exist_ok=True)
results.to_csv(results_location, index=False)