In [None]:
import json
import os
import shutil
from datetime import datetime, timedelta, timezone

import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

In [None]:
# Spark session settings as (key, value) pairs.
# Each key must appear exactly once: SparkConf.setAll applies the pairs in order,
# so a repeated key silently overrides the earlier value. "spark.driver.memory"
# is therefore listed a single time, with the value that was actually in effect.
# The list is named `spark_config` so it is not confused with the JSON settings
# that are loaded into `config` further down the notebook.
spark_config = [
    ("spark.jars.packages", "graphframes:graphframes:0.8.2-spark3.2-s_2.12"),
    ("spark.executor.memory", "32g"),
    ("spark.driver.memory", "64g"),
    ("spark.driver.maxResultSize", "64g"),
    ("spark.sql.session.timeZone", "UTC"),
]
spark = (
    SparkSession.builder.appName("temporal_graph")
    .config(conf=SparkConf().setAll(spark_config))
    .getOrCreate()
)

In [None]:
# Read the notebook settings from the JSON file one directory up.
with open("../config.json") as config_file:
    config = json.load(config_file)

DATASET = config["dataset"]
# Width of the join window; used later as a number of days.
WINDOW = config["temporal_graph"]['window']

# Load the preprocessed transactions and expose the transaction id as "id".
transactions_path = f"../data/01-ibm-transactions-for-aml/preprocessed/{DATASET}-transactions"
transactions = pd.read_parquet(transactions_path).rename(columns={"transaction_id": "id"})

In [None]:
# Keep a 10% random subset of the rows; the fixed seed makes the draw repeatable.
transactions = transactions.sample(random_state=42, frac=0.1)

In [None]:
DATA_FOLDER = "../data/01-ibm-transactions-for-aml/temporal_graph"
location_staging = os.path.join(DATA_FOLDER, f"{DATASET}_staging")
location_transactions = os.path.join(DATA_FOLDER, f"{DATASET}_transactions")

# pandas does not create missing parent directories when writing parquet.
os.makedirs(DATA_FOLDER, exist_ok=True)

# Derive the new columns into a separate frame instead of mutating
# `transactions`, so this cell can be re-run without a KeyError on "timestamp".
parsed_timestamps = pd.to_datetime(transactions["timestamp"])
transactions_staged = transactions.drop(columns=["timestamp"]).assign(
    # Epoch seconds. An explicit "int64" avoids the platform-dependent width of
    # the builtin `int` (32-bit on Windows with older NumPy).
    # NOTE(review): the division assumes nanosecond-resolution datetimes - confirm
    # for pandas versions that keep a coarser unit.
    transaction_timestamp=parsed_timestamps.astype("int64") // 10**9,
    # Calendar date of the transaction; used as the partition key below.
    transaction_date=parsed_timestamps.dt.date,
)

transactions_staged.to_parquet(location_staging)

# Rewrite the staged data with Spark, one partition directory per transaction date.
staged = spark.read.parquet(location_staging)
(
    staged.repartition("transaction_date")
    .write.partitionBy("transaction_date")
    .mode("overwrite")
    .parquet(location_transactions)
)

In [None]:
# Load the date-partitioned transactions back as a Spark DataFrame.
data = spark.read.format("parquet").load(location_transactions)

In [None]:
# Round amounts up to the next whole number and store them as longs.
amount_rounded = sf.ceil("amount").cast("long")
data = data.withColumn("amount", amount_rounded)

# Re-base timestamps so that the earliest transaction sits at 0.
min_timestamp = data.select(sf.min("transaction_timestamp")).first()[0]
shifted_timestamp = sf.col("transaction_timestamp") - min_timestamp
data = data.withColumn("transaction_timestamp", shifted_timestamp)

In [None]:
# Sort by timestamp first, then by date.
data = data.sort("transaction_timestamp", "transaction_date")

In [None]:
def rename_columns(dataframe, names):
    """Return `dataframe` with columns renamed according to `names`.

    Args:
        dataframe: object exposing `withColumnRenamed(old, new)`, such as a
            Spark DataFrame.
        names: mapping of current column name -> new column name. Renames are
            applied one after another in the mapping's iteration order.

    Returns:
        The result of chaining every rename; `dataframe` itself when `names`
        is empty.
    """
    renamed = dataframe
    for old_name in names:
        renamed = renamed.withColumnRenamed(old_name, names[old_name])
    return renamed

def max_timestamp(dt):
    """Return the epoch second at which the day after `dt` begins, in UTC.

    The value is an exclusive upper bound for timestamps that fall on `dt`.
    The datetime is built as UTC-aware: a naive datetime would be interpreted
    in the machine's local time zone by `.timestamp()`, shifting the bound by
    the local UTC offset relative to the UTC epoch seconds it is compared with.

    Args:
        dt: date string formatted as "YYYY-MM-DD".

    Returns:
        float: seconds since 1970-01-01T00:00:00Z for midnight of the next day.

    Raises:
        ValueError: if `dt` does not split into three integer parts or does
            not describe a valid calendar date.
    """
    year, month, day = (int(part) for part in dt.split("-"))
    day_start = datetime(year, month, day, tzinfo=timezone.utc)
    return (day_start + timedelta(days=1)).timestamp()

In [None]:
location_joins = os.path.join(DATA_FOLDER, f"{DATASET}_joins")
# Start from a clean output directory; a missing directory is not an error.
shutil.rmtree(location_joins, ignore_errors=True)

# Every column gets a "_left" suffix on the left-hand side of the self-join.
left_columns = {column: f"{column}_left" for column in data.columns}

distinct_dates = data.select("transaction_date").distinct().toPandas()
dates = sorted(str(value) for value in distinct_dates["transaction_date"].tolist())

for start_index, transaction_date in enumerate(dates):
    print(transaction_date)

    # Right side: this date plus the following available dates, limited both to
    # WINDOW further entries in the list and to WINDOW calendar days ahead.
    end_index = start_index + WINDOW + 1
    end_date_max = str(pd.to_datetime(transaction_date).date() + timedelta(days=WINDOW))
    right_dates = [x for x in dates[start_index:end_index] if x <= end_date_max]
    partition_paths = [f"{location_transactions}{os.sep}transaction_date={x}" for x in right_dates]
    right = spark.read.option("basePath", location_transactions).parquet(*partition_paths)

    # Left side: rows of the same read whose timestamp is before the end of
    # `transaction_date`, with every column renamed to its "_left" variant.
    cutoff = max_timestamp(transaction_date)
    left = rename_columns(right.where(right.transaction_timestamp < cutoff), left_columns)

    # Pair each left transaction with right transactions whose source is the
    # left transaction's target, keeping only pairs where the right one is later.
    join = left.join(right, left.target_left == right.source, "inner")
    join = join.withColumn("delta", join.transaction_timestamp - join.transaction_timestamp_left)
    join = join.where(join.delta > 0)
    join.write.parquet(f"{location_joins}/staging_date={transaction_date}", mode="overwrite")

In [None]:
# Load every per-date join output written above as a single DataFrame.
joins = spark.read.format("parquet").load(location_joins)

In [None]:
# Output directories for the edge and node tables.
edges_location = os.path.join(DATA_FOLDER, f"{DATASET}_edges")
nodes_location = os.path.join(DATA_FOLDER, f"{DATASET}_nodes")

In [None]:
# Node table: one row per transaction id, restricted to these columns.
node_columns = [
    "id",
    "source",
    "target",
    "transaction_date",
    "transaction_timestamp",
    "amount",
]
nodes = data.select(*node_columns).dropDuplicates(["id"])

# Edge weight: the smaller of the two amounts divided by the larger one.
amount_src = sf.col("amount_left")
amount_dst = sf.col("amount")
weight = sf.when(amount_src > amount_dst, amount_dst / amount_src).otherwise(amount_src / amount_dst)

# Edge table: one row per joined pair, from the left transaction to the right one.
edges = joins.select(
    sf.col("id_left").alias("src"),
    sf.col("id").alias("dst"),
    sf.col("transaction_date_left").alias("src_date"),
    sf.col("transaction_date").alias("dst_date"),
    weight.alias("weight"),
)

# Persist nodes, partitioned by transaction date.
nodes = nodes.repartition("transaction_date")
nodes.write.partitionBy("transaction_date").mode("overwrite").parquet(nodes_location)

# Persist edges, partitioned by the dates of both endpoints.
partition_by = ["src_date", "dst_date"]
edges_writer = edges.repartition(*partition_by).write.partitionBy(*partition_by)
edges_writer.mode("overwrite").parquet(edges_location)

# Continue from the written files rather than from the in-memory query plans.
nodes = spark.read.parquet(nodes_location)
edges = spark.read.parquet(edges_location)

In [None]:
# Preview the first 10 node rows.
nodes.show(n=10)

In [None]:
# Preview the first 10 edge rows.
edges.show(n=10)

In [None]:
# Report the size of the graph, then shut the Spark session down.
node_count = nodes.count()
print("# of nodes", node_count)
edge_count = edges.count()
print("# of edges", edge_count)
spark.stop()