### Init Context

In [None]:
from thetaray.api.context import init_context
import datetime
import yaml

import logging
logging.basicConfig(level=logging.DEBUG, format='%(message)s')

# Spark settings for this domain live in a YAML file; only one named section is used.
SPARK_CONFIG_PATH = '/thetaray/git/solutions/domains/demo_ret_smb/config/spark_config.yaml'
SPARK_CONFIG_SECTION = 'spark_config_a'

# safe_load resolves only standard YAML tags. That is all a settings mapping needs,
# and unlike FullLoader it never constructs Python-specific objects.
with open(SPARK_CONFIG_PATH, encoding='utf-8') as spark_config_file:
    spark_config = yaml.safe_load(spark_config_file)[SPARK_CONFIG_SECTION]

# NOTE(review): the execution date is a fixed placeholder (1970-02-01); confirm this is
# intended for the demo rather than the real run date.
context = init_context(execution_date=datetime.datetime(1970, 2, 1),
                       spark_conf=spark_config,
                       delete_unused_columns=True,
                       spark_master='local[*]')

### Imports

In [None]:
from thetaray.api.context import init_context
from thetaray.api.dataset import dataset_functions
from thetaray.api.evaluation import load_evaluated_activities, read_alerted_activities
from thetaray.api.graph import publish_edges, publish_nodes

from domains.demo_ret_smb.datasets.customers import customers_dataset
from domains.demo_ret_smb.datasets.transactions import transactions_dataset
from domains.demo_ret_smb.evaluation_flows.ef import evaluation_flow

from datetime import datetime

import pandas as pd
from pyspark.sql import functions as f
from pyspark.sql.types import *

### Account nodes (customers and counterparties)

In [None]:
from pyspark.sql import Window

# Customer accounts become AC nodes. AN mirrors the node id and AD is left empty.
customer_df = dataset_functions.read(context, customers_dataset().identifier)
nodes_df = customer_df.select(
    f.col("customer_id").alias("id"),
    f.col("incorporation_country").alias("CT"),
    f.col("business_name").alias("NM"),
    f.lit("").alias("AD"),
    f.col("customer_id").alias("AN"),
    f.lit(context.execution_date).alias("effective_date"),
)

trx_df = dataset_functions.read(context, transactions_dataset().identifier)
trx_df = trx_df.where(f.col("counterparty_account").isNotNull())

# A counterparty account appears in many transactions. Keep the attributes from its most
# recent one (transaction_id breaks timestamp ties), so the chosen CT/NM is
# deterministic; dropDuplicates would keep an arbitrary row.
latest_first = Window.partitionBy("counterparty_account").orderBy(
    f.col("transaction_timestamp").desc(), f.col("transaction_id").desc()
)
cp_nodes_df = (
    trx_df.withColumn("_rn", f.row_number().over(latest_first))
    .where(f.col("_rn") == 1)
    .select(
        f.col("counterparty_account").alias("id"),
        f.col("counterparty_country").alias("CT"),
        f.col("counterparty_customer_name").alias("NM"),
        f.lit("").alias("AD"),
        f.col("counterparty_account").alias("AN"),
        f.lit(context.execution_date).alias("effective_date"),
    )
)

# TX edges use both customer_id and counterparty_account as AC node ids, so a counterparty
# that is also a customer must not be published twice; the customer record wins.
cp_nodes_df = cp_nodes_df.join(nodes_df.select("id"), on="id", how="left_anti")

# Match columns by name rather than by position.
nodes_df = nodes_df.unionByName(cp_nodes_df)

publish_nodes(context, nodes_df, "demo_ret_smb_graph", "AC")

### Transaction (TX) edges between accounts

In [None]:
# Re-read here so this cell can run on its own, independently of the nodes cell.
trx_df = dataset_functions.read(context, transactions_dataset().identifier)
trx_df = trx_df.where(f.col("counterparty_account").isNotNull())


def _to_tx_edges(df, source_col, target_col):
    """Project transactions onto the TX edge schema.

    Args:
        df: transactions already filtered to a single direction.
        source_col: column holding the edge's source AC node id.
        target_col: column holding the edge's target AC node id.

    Returns:
        DataFrame with columns id, effective_date, source_node, target_node, AM, CR, CT,
        one row per transaction. CT is a constant per-transaction count of 1 (long).
    """
    return df.select(
        f.col("transaction_id").alias("id"),
        f.col("transaction_timestamp").alias("effective_date"),
        f.col(source_col).alias("source_node"),
        f.col(target_col).alias("target_node"),
        f.col("original_trx_amount").alias("AM"),
        f.col("original_trx_currency").alias("CR"),
        f.lit(1).cast("long").alias("CT"),
    )


# IN transactions point counterparty -> customer, OUT transactions point customer -> counterparty.
# Rows whose in_out is neither "IN" nor "OUT" produce no edge.
incoming_edges_df = _to_tx_edges(
    trx_df.where(f.col('in_out') == "IN"), "counterparty_account", "customer_id"
)
outgoing_edges_df = _to_tx_edges(
    trx_df.where(f.col('in_out') == "OUT"), "customer_id", "counterparty_account"
)

# Diagnostic only: each count() runs a full Spark job.
print(f"incoming TX edges: {incoming_edges_df.count()}")
print(f"outgoing TX edges: {outgoing_edges_df.count()}")

edges_df = incoming_edges_df.unionByName(outgoing_edges_df)

publish_edges(context, edges_df, "demo_ret_smb_graph", "TX", "AC", "AC")

### Read alerted activities

In [None]:
# Build the evaluation flow once and reuse its identifier for both reads.
ef_identifier = evaluation_flow().identifier
act_df = read_alerted_activities(context, ef_identifier)
eval_act_df = load_evaluated_activities(context, ef_identifier)
# NOTE(review): if both frames share column names other than tr_id (e.g. customer_id or
# year_month), the select below becomes ambiguous; confirm the two schemas.
joined_act_df = eval_act_df.join(act_df, "tr_id")
selected_activity_fields = joined_act_df.select("tr_id", "risk_id", "year_month", "is_suppressed", "customer_id")
# Cached because both the alert-node and alert-edge cells derive from it; released by
# selected_activity_fields.unpersist() later in the notebook.
selected_activity_fields.cache()

### Extract and publish alert nodes

In [None]:
# One AL node per selected activity row. The node id is "<tr_id>_<risk_id>"; customer_id is
# not carried as a node attribute. (f.concat yields null if tr_id or risk_id is null.)
al_nodes_df = selected_activity_fields.select(
    f.col("tr_id").alias("AI"),
    f.col("risk_id").alias("RI"),
    f.col("year_month").alias("effective_date"),
    f.col("is_suppressed").alias("SP"),
    f.concat(f.col("tr_id"), f.lit("_"), f.col("risk_id")).alias("id"),
)

In [None]:
# Publish the alert nodes to the demo graph under node type "AL".
publish_nodes(context, al_nodes_df, "demo_ret_smb_graph", "AL")

### Extract and publish alert-to-account edges

In [None]:
# Alert -> account edges. The edge id and its source node are both "<tr_id>_<risk_id>",
# built the same way as the AL node id. The target is the activity's customer_id.
al_edges_df = selected_activity_fields.select(
    f.col("year_month").alias("effective_date"),
    f.col("customer_id").alias("target_node"),
    f.concat(f.col("tr_id"), f.lit("_"), f.col("risk_id")).alias("id"),
)
al_edges_df = al_edges_df.withColumn("source_node", f.col("id"))

In [None]:
# Publish the alert edges: edge type "AL", from AL (alert) nodes to AC (account) nodes.
publish_edges(
    context,
    al_edges_df,
    "demo_ret_smb_graph",
    edge_type="AL",
    source_node_type="AL",
    target_node_type="AC",
)

In [None]:
# Release the cache taken on the selected activity fields once both alert cells are done.
selected_activity_fields.unpersist()

In [None]:
# Close the ThetaRay context. NOTE(review): presumably this also stops the local Spark
# session started by init_context; confirm.
context.close()