## Temporal Graph of Sequential Transactions
*Code adapted from https://github.com/mhaseebtariq/fastman/tree/main*

In [None]:
import os
import gc
import shutil
import pandas as pd
from datetime import datetime, timedelta
from pyspark.sql import functions as sf
from pyspark import SparkConf
from pyspark.sql import SparkSession

# Instantiate PySpark session
# Spark settings as (key, value) pairs; every key is listed exactly once.
config = [
    ("spark.driver.memory", "128g"),
    ("spark.executor.memory", "128g"),
    ("spark.driver.maxResultSize", "128g"),
    ("spark.sql.session.timeZone", "UTC"),
]
spark = (
    SparkSession.builder
    .appName("06_temporal_graph_creation")
    .config(conf=SparkConf().setAll(config))
    .getOrCreate()
)

In [None]:
# Dataset to process: either "HI-Small" or "LI-Small".
DATASET = "HI-Small"
# Time window parameter for connecting transactions in sequence.
WINDOW = 5

# Output folder for everything this notebook writes.
DATA_FOLDER = "../datasets/synthetic/05_temporal_graph"
# Preprocessed transactions of the selected dataset (input of this notebook).
input_path = f"../datasets/synthetic/02_preprocessed/{DATASET}-transactions.parquet"

In [None]:
# Load the preprocessed transactions, which must provide at least
# ['transaction_id', 'source', 'target', 'timestamp', 'amount'],
# and expose the transaction identifier under the shorter name `id`.
transactions = pd.read_parquet(input_path).rename(columns={"transaction_id": "id"})

In [None]:
location_staging = os.path.join(DATA_FOLDER, f"{DATASET}_staging")
# pandas does not create missing parent directories when writing parquet.
os.makedirs(DATA_FOLDER, exist_ok=True)

transactions["transaction_timestamp"] = pd.to_datetime(transactions["timestamp"])
transactions["transaction_date"] = transactions["transaction_timestamp"].dt.date
# Convert to whole seconds since the Unix epoch. Subtracting the epoch and
# floor-dividing by one second is independent of the column's datetime
# resolution and of the platform's default integer width.
epoch = pd.Timestamp(0, tz=transactions["transaction_timestamp"].dt.tz)
transactions["transaction_timestamp"] = (
    (transactions["transaction_timestamp"] - epoch) // pd.Timedelta(seconds=1)
)
del transactions["timestamp"]

# Stage the pandas frame on disk so Spark can pick it up.
transactions.to_parquet(location_staging)

# Re-write the staged transactions with Spark as a parquet dataset that is
# partitioned on disk by `transaction_date`, replacing any previous output.
location_transactions = os.path.join(DATA_FOLDER, f"{DATASET}_transactions")
staged = spark.read.parquet(location_staging)
partitioned_writer = (
    staged.repartition("transaction_date")
    .write.mode("overwrite")
    .partitionBy("transaction_date")
)
partitioned_writer.parquet(location_transactions)

In [None]:
# Load the date-partitioned transactions back as a Spark DataFrame.
data = spark.read.format("parquet").load(location_transactions)

In [None]:
# Round amounts up to whole units and store them as long integers.
data = data.withColumn("amount", sf.ceil(sf.col("amount")).cast("long"))
# Shift timestamps so that the earliest transaction is at 0.
min_timestamp = data.agg(sf.min("transaction_timestamp")).first()[0]
data = data.withColumn(
    "transaction_timestamp",
    sf.col("transaction_timestamp") - min_timestamp,
)

In [None]:
# Order the transactions by timestamp, breaking ties by date.
data = data.orderBy(["transaction_timestamp", "transaction_date"])

In [None]:
def rename_columns(dataframe, names):
    """Return `dataframe` with columns renamed according to `names`.

    `names` maps each current column name to its new name. The renames are
    applied one at a time, in the mapping's iteration order, through
    `withColumnRenamed`; an empty mapping returns `dataframe` itself.
    """
    renamed = dataframe
    for current_name in names:
        renamed = renamed.withColumnRenamed(current_name, names[current_name])
    return renamed

def max_timestamp(dt):
    """Return the Unix timestamp of the midnight (UTC) that ends the day `dt`.

    `dt` is a date string of the form "YYYY-MM-DD". The result is the number
    of seconds, as a float, between 1970-01-01 00:00:00 and 00:00:00 of the
    day after `dt`, with both instants taken in UTC.

    Raises ValueError if `dt` does not split into exactly three integer
    parts or does not describe a valid calendar date.
    """
    year, month, day = (int(part) for part in dt.split("-"))
    next_midnight = datetime(year, month, day) + timedelta(days=1)
    # Subtract the epoch explicitly instead of calling `.timestamp()` on a
    # naive datetime: the latter interprets the value in the machine's local
    # time zone, which would shift the cut-off by the local UTC offset.
    return (next_midnight - datetime(1970, 1, 1)).total_seconds()

In [None]:
# Prepare the joins of sequential transactions
# Every column of the left side of a join gets a `_left` suffix.
left_columns = {x: f"{x}_left" for x in data.columns}
# Distinct transaction dates as strings, in ascending order.
date_rows = data.select("transaction_date").distinct().collect()
dates = sorted(str(row["transaction_date"]) for row in date_rows)

# Start from an empty joins folder (a missing folder is not an error).
location_joins = os.path.join(DATA_FOLDER, f"{DATASET}_joins")
shutil.rmtree(location_joins, ignore_errors=True)

In [None]:
# Perform a flow join between transactions, where a link exists between two nodes
# if the target attribute of the first node matches the source of another, within
# the specified time window.
for start_index, transaction_date in enumerate(dates):
    print(transaction_date)
    # `dates` is sorted: take the current date plus at most the next WINDOW
    # dates present in the data ...
    end_index = start_index + WINDOW + 1
    right_dates = dates[start_index:end_index]
    # ... and keep only those at most WINDOW calendar days ahead (the date
    # strings are compared lexicographically).
    end_date_max = str(pd.to_datetime(transaction_date).date() + timedelta(days=WINDOW))
    right_dates = [x for x in right_dates if x <= end_date_max]
    # Read just the selected date partitions; `basePath` keeps
    # `transaction_date` available as a column.
    right = spark.read.option("basePath", location_transactions).parquet(
        *[f"{location_transactions}{os.sep}transaction_date={x}" for x in right_dates]
    )
    # Left side: rows timestamped before the end of the current date, with
    # every column suffixed `_left`.
    left = rename_columns(right.where(right.transaction_timestamp < max_timestamp(transaction_date)), left_columns)
    flow_join = left.join(right, left.target_left == right.source, "inner")
    flow_join = flow_join.withColumn("delta", flow_join.transaction_timestamp - flow_join.transaction_timestamp_left)
    # Keep only pairs in which the right transaction is strictly later.
    flow_join = flow_join.where(flow_join.delta > 0)
    flow_join.write.parquet(f"{location_joins}/type=flow/staging_date={transaction_date}", mode="overwrite")

In [None]:
# Perform a fan_in join between transactions, where a link exists between two nodes
# if the target attribute of the first node matches the target of another, within
# the specified time window.
for start_index, transaction_date in enumerate(dates):
    print(transaction_date)
    # `dates` is sorted: take the current date plus at most the next WINDOW
    # dates present in the data ...
    end_index = start_index + WINDOW + 1
    right_dates = dates[start_index:end_index]
    # ... and keep only those at most WINDOW calendar days ahead (the date
    # strings are compared lexicographically).
    end_date_max = str(pd.to_datetime(transaction_date).date() + timedelta(days=WINDOW))
    right_dates = [x for x in right_dates if x <= end_date_max]
    # Read just the selected date partitions; `basePath` keeps
    # `transaction_date` available as a column.
    right = spark.read.option("basePath", location_transactions).parquet(
        *[f"{location_transactions}{os.sep}transaction_date={x}" for x in right_dates]
    )
    # Left side: rows timestamped before the end of the current date, with
    # every column suffixed `_left`.
    left = rename_columns(right.where(right.transaction_timestamp < max_timestamp(transaction_date)), left_columns)
    f_in_join = left.join(right, left.target_left == right.target, "inner")
    f_in_join = f_in_join.withColumn("delta", f_in_join.transaction_timestamp - f_in_join.transaction_timestamp_left)
    # Keep only pairs in which the right transaction is strictly later.
    f_in_join = f_in_join.where(f_in_join.delta > 0)
    f_in_join.write.parquet(f"{location_joins}/type=fan_in/staging_date={transaction_date}", mode="overwrite")

In [None]:
# Perform a fan_out join between transactions, where a link exists between two nodes
# if the source attribute of the first node matches the source of another, within
# the specified time window.
for start_index, transaction_date in enumerate(dates):
    print(transaction_date)
    # `dates` is sorted: take the current date plus at most the next WINDOW
    # dates present in the data ...
    end_index = start_index + WINDOW + 1
    right_dates = dates[start_index:end_index]
    # ... and keep only those at most WINDOW calendar days ahead (the date
    # strings are compared lexicographically).
    end_date_max = str(pd.to_datetime(transaction_date).date() + timedelta(days=WINDOW))
    right_dates = [x for x in right_dates if x <= end_date_max]
    # Read just the selected date partitions; `basePath` keeps
    # `transaction_date` available as a column.
    right = spark.read.option("basePath", location_transactions).parquet(
        *[f"{location_transactions}{os.sep}transaction_date={x}" for x in right_dates]
    )
    # Left side: rows timestamped before the end of the current date, with
    # every column suffixed `_left`.
    left = rename_columns(right.where(right.transaction_timestamp < max_timestamp(transaction_date)), left_columns)
    f_out_join = left.join(right, left.source_left == right.source, "inner")
    f_out_join = f_out_join.withColumn("delta", f_out_join.transaction_timestamp - f_out_join.transaction_timestamp_left)
    # Keep only pairs in which the right transaction is strictly later.
    f_out_join = f_out_join.where(f_out_join.delta > 0)
    f_out_join.write.parquet(f"{location_joins}/type=fan_out/staging_date={transaction_date}", mode="overwrite")

In [None]:
# Run a full garbage-collection pass in the Python process.
gc.collect()

In [None]:
# Load everything written under the joins folder in one go. The joins are
# stored under `type=.../staging_date=...` sub-folders of this location.
location_joins = os.path.join(DATA_FOLDER, f"{DATASET}_joins")
joins = spark.read.format("parquet").load(location_joins)

In [None]:
# Output locations of the temporal graph: one parquet dataset for the nodes
# and one for the edges, both inside DATA_FOLDER.
nodes_location = os.path.join(DATA_FOLDER, f"{DATASET}_nodes")
edges_location = os.path.join(DATA_FOLDER, f"{DATASET}_edges")

In [None]:
# Select the attributes for nodes (transactions) and write nodes
node_columns = ["id", "source", "target", "transaction_date", "transaction_timestamp", "amount"]
# One node per transaction id, grouped by date before writing.
nodes = (
    data.select(node_columns)
    .dropDuplicates(["id"])
    .repartition("transaction_date")
)
nodes.write.mode("overwrite").partitionBy("transaction_date").parquet(nodes_location)

In [None]:
# Run a full garbage-collection pass in the Python process.
gc.collect()

In [None]:
# Extract edges from the joins, rename the columns and compute weights
amount_left, amount_right = sf.col("amount_left"), sf.col("amount")
# fan_in / fan_out edges always weigh 1; any other edge type weighs the ratio
# of the smaller amount to the larger one.
edge_weight = (
    sf.when(sf.col("type").isin("fan_in", "fan_out"), 1)
    .when(amount_left > amount_right, amount_right / amount_left)
    .otherwise(amount_left / amount_right)
)
edges = joins.select(
    sf.col("id_left").alias("src"),
    sf.col("id").alias("dst"),
    sf.col("transaction_date_left").alias("src_date"),
    sf.col("transaction_date").alias("dst_date"),
    sf.round(edge_weight, 6).alias("weight"),
    sf.col("type").alias("edge_type"),
    sf.col("delta"),
)

# Write the edges partitioned by the dates of both endpoints.
partition_by = ["src_date", "dst_date"]
edges_writer = edges.repartition(*partition_by).write.mode("overwrite").partitionBy(*partition_by)
edges_writer.parquet(edges_location)

In [None]:
# nodes = spark.read.parquet(nodes_location)
# edges = spark.read.parquet(edges_location)
# print("# of nodes", nodes.count())
# print("# of edges", edges.count())

In [None]:
# Shut down the Spark session now that all outputs have been written.
spark.stop()