In [None]:
from datetime import datetime

import pandas as pd
from pyspark.sql import functions as f
from pyspark.sql.types import *


from thetaray.api.context import init_context
from thetaray.api.dataset import dataset_functions
from thetaray.api.evaluation import load_evaluated_activities, read_alerted_activities
from thetaray.api.graph import publish_edges, publish_nodes

In [None]:
# Create the platform context for a fixed execution date and grab its Spark session.
# NOTE(review): 1970-01-01 looks like a placeholder execution date — confirm.
execution_date = datetime(1970, 1, 1)
context = init_context(execution_date=execution_date)
spark = context.get_spark_session()

### Read transactions

In [None]:
# Load the "transaction" dataset through the platform dataset API. The result is
# joined with the supplementary CSV data on "trans_id" further down.
trans = dataset_functions.read(context, "transaction")

In [None]:
# Read the supplementary per-transaction graph data into a pandas DataFrame; it is
# converted to a Spark DataFrame with an explicit schema in the next cell.
# NOTE(review): the path is absolute and hard-coded — consider a configurable data directory.
trans_additional_pd = pd.read_csv("/thetaray/data/bank/trans_graph_data.csv")

In [None]:
# Every supplementary column is declared as a nullable string.
trans_schema = StructType(
    [
        StructField(field_name, StringType(), True)
        for field_name in ("trans_id", "dest_account_id", "currency")
    ]
)

# Convert the pandas data to a Spark DataFrame using the explicit schema above.
trans_additional_df = spark.createDataFrame(trans_additional_pd, schema=trans_schema)

In [None]:
# Enrich each transaction with its supplementary columns; only transactions that
# have a matching "trans_id" on both sides survive the (inner) join.
trans_full = trans.join(trans_additional_df, on="trans_id", how="inner")

# Quick visual sanity check of the joined result.
trans_full.show()

### Extract and publish transaction - account edges

In [None]:
# Project the joined transactions straight into the edge layout: each source
# column is selected and given its published name in one step.
edges_df = trans_full.select(
    f.col("trans_id").cast("string").alias("id"),  # edge id is the transaction id as a string
    f.col("date").alias("effective_date"),
    f.col("amount").alias("AM"),
    f.col("account_id").alias("source_node"),
    f.col("dest_account_id").alias("target_node"),
    f.col("currency").alias("CR"),
    f.lit(1).alias("count"),  # constant 1 per transaction
).withColumn(
    # "CT" is the same constant as "count", cast to long; "count" itself is kept.
    "CT",
    f.col("count").cast("long"),
)

In [None]:
# Publish the transaction edges built above.
# NOTE(review): arguments are positional; judging by the keyword call later in this
# notebook they are presumably edge_type="TX", source_node_type="AC",
# target_node_type="AC" — confirm against the publish_edges signature.
publish_edges(context, edges_df, "public", "TX", "AC", "AC")

### Read alerted activities

In [None]:
# Read the alerted activities for "tr_analysis"; joined with the evaluated
# activities on "tr_id" below.
act_df = read_alerted_activities(context, "tr_analysis")

In [None]:
# Load the evaluated activities for "tr_analysis".
# NOTE(review): the third argument is an empty string; its meaning is not visible
# here — confirm that passing "" is intentional.
eval_act_df = load_evaluated_activities(context, "tr_analysis", "")

In [None]:
# Combine evaluated and alerted activities on their shared "tr_id" key.
# NOTE(review): if both inputs carry other same-named columns, selecting them by
# name afterwards could be ambiguous — verify the input schemas.
joined_act_df = eval_act_df.join(act_df, "tr_id")

In [None]:
# Keep only the activity columns needed for the alert nodes and alert edges.
activity_columns = ["tr_id", "risk_id", "date_loan", "is_suppressed", "account_id"]
selected_activity_fields = joined_act_df.select(*activity_columns)

# Cached because two separate downstream transformations read this frame;
# it is unpersisted near the end of the notebook.
selected_activity_fields.cache()

### Extract and publish alert nodes

In [None]:
# Shape the activity fields into alert nodes. The node id is "<tr_id>_<risk_id>";
# account_id is intentionally not part of the node.
al_nodes_df = selected_activity_fields.select(
    f.col("tr_id").alias("AI"),
    f.col("risk_id").alias("RI"),
    f.col("date_loan").alias("effective_date"),
    f.col("is_suppressed").alias("SP"),
    f.concat(f.col("tr_id"), f.lit("_"), f.col("risk_id")).alias("id"),
)

In [None]:
# Publish the alert nodes built above, passing "public" and "AL" to publish_nodes.
# NOTE(review): presumably "AL" is the node type (it matches source_node_type="AL"
# in the alert-edge publish call) — confirm against the publish_nodes signature.
publish_nodes(context, al_nodes_df, "public", "AL")

### Extract and publish alert - account edges

In [None]:
# Build the alert -> account edges from the cached activity fields.
# The edge id is "<tr_id>_<risk_id>", the same expression used for the alert node id,
# so the edge's source_node points at the matching alert node and its target_node
# is the activity's account_id.
al_edges_df = (
    selected_activity_fields.withColumn("id", f.concat(f.col("tr_id"), f.lit("_"), f.col("risk_id")))
    .withColumnRenamed("date_loan", "effective_date")
    .withColumn("source_node", f.col("id"))
    .withColumnRenamed("account_id", "target_node")
    # Drop everything that is not part of the edge; the remaining columns are
    # effective_date, target_node, id and source_node.
    .drop("is_suppressed", "tr_id", "risk_id")
)

In [None]:
# Publish the alert -> account edges: edge type "AL", from "AL" (alert) nodes
# to "AC" nodes, passing "public" as the third positional argument.
publish_edges(
    context,
    al_edges_df,
    "public",
    edge_type="AL",
    source_node_type="AL",
    target_node_type="AC",
)

In [None]:
# Release the cached activity fields now that both the alert nodes and the
# alert edges have been published.
selected_activity_fields.unpersist()

In [None]:
# Close the platform context created at the top of the notebook.
context.close()