### Init Context

In [1]:
from thetaray.api.context import init_context
from datetime import datetime
import yaml

import logging
# Configure root logging at DEBUG level, printing only the message text.
logging.basicConfig(format='%(message)s', level=logging.DEBUG)

# Read the Spark settings used by this notebook from the solution's config file.
# NOTE(review): yaml.load with FullLoader is kept as-is; if this file only holds
# plain mappings/scalars, yaml.safe_load would be the safer choice - confirm
# before switching.
with open('/thetaray/git/solutions/domains/demo_digital_wallets/config/spark_config.yaml') as spark_config_file:
    spark_settings = yaml.load(spark_config_file, Loader=yaml.FullLoader)
spark_config = spark_settings['spark_config_a']

# Initialise the platform context for a fixed execution date, using the Spark
# configuration loaded above.
# TODO: the explicit spark_conf override was marked for removal by the author.
context = init_context(
    execution_date=datetime(1970, 2, 1),
    spark_conf=spark_config,
)

2025-08-07 14:27:35,656:INFO:thetaray.common.logging:start loading solution.....[ load_risks=True , solution_path=/thetaray/git/solutions/domains , settings_path=/thetaray/git/solutions/settings ]
2025-08-07 14:27:36,050:INFO:thetaray.common.logging:load_risks took: 0.19347643852233887
2025-08-07 14:27:36,680:INFO:thetaray.common.logging:=== Started updating schema ===


### Imports

In [2]:
from thetaray.api.context import init_context
from thetaray.api.dataset import dataset_functions
from thetaray.api.evaluation import load_evaluated_activities, read_alerted_activities
from thetaray.api.graph import publish_edges, publish_nodes

from domains.demo_digital_wallets.datasets.customers import customers_dataset
from domains.demo_digital_wallets.datasets.transactions import transactions_dataset
from domains.demo_digital_wallets.evaluation_flows.ef import evaluation_flow

from datetime import datetime

from thetaray.common.data_environment import DataEnvironment

import pandas as pd
from pyspark.sql import functions as f
from pyspark.sql.types import *

### Nodes

In [None]:
from pyspark.sql import functions as f
from pyspark.sql import Window
from thetaray.api.dataset import dataset_functions
from thetaray.common.data_environment import DataEnvironment

# Build and publish the "AC" nodes of the "demo_dwallets_graph" graph.
# Nodes come from two sources that are shaped into the same column layout:
#   1. every row of the customers dataset, and
#   2. every distinct counterparty seen in the transactions dataset.
AC_NODE_COLUMNS = ["id", "CN", "NM", "AD", "AN", "effective_date"]

# Customers -> AC
customer_df = dataset_functions.read(
    context,
    customers_dataset().identifier
)

nodes_ac_customers = (
    customer_df.select(
        f.col("client_id").alias("AN"),  # node key
        f.col("client_name").alias("NM"),
        f.col("country_of_residence_code").alias("CN"),
        f.col("address").alias("AD"),
    )
    .withColumn("id", f.col("AN"))  # technical id is the same value as AN
    .withColumn("effective_date", f.lit(context.execution_date))
    .select(*AC_NODE_COLUMNS)
)

# Counterparties observed in transactions -> AC
trx_df_for_nodes = dataset_functions.read(
    context,
    transactions_dataset().identifier
).where(f.col("counterparty_id").isNotNull())

# The counterparty sits on the destination side of an outflow and on the
# origin side of an inflow; for any other direction value take whichever
# country is populated.
cp_country = (
    f.when(f.col("direction") == f.lit("outflow"), f.col("country_destination"))
     .when(f.col("direction") == f.lit("inflow"),  f.col("country_origin"))
     .otherwise(f.coalesce(f.col("country_destination"), f.col("country_origin")))
)

cp_raw = trx_df_for_nodes.select(
    f.col("counterparty_id").alias("AN"),
    f.col("counterparty_name").alias("NM"),
    cp_country.alias("CN"),
    f.lit("").alias("AD"),  # no address is taken from transactions
    f.col("transaction_datetime").alias("_ts")
)

# Keep one row per counterparty: the one from its most recent transaction
# (rows with a null timestamp sort last).
w_cp = Window.partitionBy("AN").orderBy(f.col("_ts").desc_nulls_last())
nodes_ac_cp = (
    cp_raw
    .withColumn("_rn", f.row_number().over(w_cp))
    .filter(f.col("_rn") == 1)
    .drop("_rn", "_ts")
    # A counterparty id may also be a client_id. Without this anti-join the
    # union below would contain two rows with the same node id; the customer
    # row (which carries the address) is the one that is kept.
    .join(nodes_ac_customers.select("AN"), on="AN", how="left_anti")
    .withColumn("id", f.col("AN"))
    .withColumn("effective_date", f.lit(context.execution_date))
    .select(*AC_NODE_COLUMNS)
)

# Merge both sources and replace null text attributes with empty strings.
nodes_ac_df = nodes_ac_customers.unionByName(nodes_ac_cp)
nodes_ac_df = nodes_ac_df.fillna({"NM":"", "CN":"", "AD":""})

publish_nodes(
    context,
    nodes_ac_df,
    "demo_dwallets_graph",
    "AC")

### Edges

In [None]:
# Build and publish the "TX" edges (AC -> AC) of the "demo_dwallets_graph" graph
# from transactions that have a counterparty.
trx_df = dataset_functions.read(context, transactions_dataset().identifier).where(
    f.col("counterparty_id").isNotNull()
)

# Common projection shared by both edge directions.
base_tx = trx_df.select(
    f.col("transaction_id").alias("id"),
    f.col("transaction_datetime").alias("TS"),
    f.col("client_id"),
    f.col("counterparty_id"),
    f.col("amount").cast("double").alias("AM"),  # force a numeric amount
    f.lit(1).cast("double").alias("CT"),         # constant 1.0 per transaction
    f.col("country_origin").alias("CO"),
    f.col("country_destination").alias("CD"),
    f.col("direction"),
)


def _tx_edges(direction, source_col, target_col):
    """Return the base_tx rows of one direction shaped as source -> target edges."""
    return base_tx.where(f.col("direction") == f.lit(direction)).select(
        f.col("id"),
        f.col("TS").alias("effective_date"),
        f.col(source_col).alias("source_node"),
        f.col(target_col).alias("target_node"),
        "AM", "CT", "CO", "CD", "TS",
    )


# Inflow: the counterparty is the source and the client is the target.
incoming_edges_df = _tx_edges("inflow", "counterparty_id", "client_id")

# Outflow: the client is the source and the counterparty is the target.
outgoing_edges_df = _tx_edges("outflow", "client_id", "counterparty_id")

# Rows whose direction is neither "inflow" nor "outflow" are not part of either
# frame and therefore produce no edge.
edges_tx_df = incoming_edges_df.unionByName(outgoing_edges_df).select(
    "id", "effective_date", "source_node", "target_node", "AM", "CT", "CO", "CD"
)

publish_edges(context, edges_tx_df, "demo_dwallets_graph", "TX", "AC", "AC")


### Read alerted activities

In [None]:
from pyspark.sql import functions as f
from thetaray.api.evaluation import load_evaluated_activities, read_alerted_activities
from thetaray.common.data_environment import DataEnvironment

def first_match(df, candidates):
    """Return the first name in ``candidates`` that is a column of ``df``.

    Candidates are checked in the order given against ``df.columns``.
    Returns ``None`` when no candidate is present.
    """
    return next((name for name in candidates if name in df.columns), None)

# Load the alerted and the evaluated activities of the evaluation flow, join
# them on the transaction id and project the few fields the alert graph needs.
# Column names are resolved against lists of known aliases because the exact
# schema returned by the platform is not fixed in this notebook.
act_df = read_alerted_activities(context, evaluation_flow().identifier)
eval_act_df = load_evaluated_activities(context, evaluation_flow().identifier)

# Normalise the transaction-id column name on both sides so they can be joined.
tr_act  = first_match(act_df, ["tr_id", "transaction_id"])
tr_eval = first_match(eval_act_df, ["tr_id", "transaction_id"])
if tr_act is None or tr_eval is None:
    raise RuntimeError("No encuentro columnas de transacción (tr_id/transaction_id).")

act_df = act_df.withColumnRenamed(tr_act, "tr_id_norm")
eval_act_df = eval_act_df.withColumnRenamed(tr_eval, "tr_id_norm")

joined_act_df = eval_act_df.join(act_df, "tr_id_norm", "inner")

risk_col = first_match(joined_act_df, ["risk_id", "risk", "risk_code"])
if risk_col is None:
    raise RuntimeError("No encuentro columna de riesgo.")

# Suppression flag: first non-null value among the known flag columns, or a
# constant False when none of them exists.
suppr_cols = [c for c in ["is_suppressed", "suppressed", "is_filtered"] if c in joined_act_df.columns]
is_supp_expr = f.coalesce(*[f.col(c) for c in suppr_cols]) if suppr_cols else f.lit(False).cast("boolean")

# The candidate list previously contained "client_id" twice, so there was no
# real fallback; "customer_id" is accepted as the alternative column name.
cust_col = first_match(joined_act_df, ["client_id", "customer_id"])
if cust_col is None:
    raise RuntimeError("No encuentro columna de cliente (client_id/customer_id).")

# Use an existing year-month column when present; otherwise derive a "yyyyMM"
# string from the first timestamp-like column that is found.
ym_col = first_match(joined_act_df, ["year_month", "ym", "yearmonth"])
if ym_col is None:
    ts_col = first_match(joined_act_df, [
        "transaction_datetime", "transaction_timestamp", "effective_date",
        "event_time", "activity_ts", "tr_timestamp"
    ])
    if ts_col is None:
        raise RuntimeError("No encuentro year_month ni timestamp para derivarlo.")
    year_month_expr = f.date_format(f.col(ts_col), "yyyyMM")
else:
    year_month_expr = f.col(ym_col)

# Cached because it is reused by later cells; it is released with unpersist()
# near the end of the notebook.
selected_activity_fields = (
    joined_act_df.select(
        f.col("tr_id_norm").alias("tr_id"),
        f.col(risk_col).alias("risk_id"),
        year_month_expr.alias("year_month"),
        is_supp_expr.alias("is_suppressed"),
        f.col(cust_col).alias("client_id"),
    )
).cache()


### Extract and publish alert nodes

In [18]:
from pyspark.sql import functions as f

sf = selected_activity_fields

# year_month is either a six-digit "yyyyMM" string (mapped to the first day of
# that month) or a value that to_timestamp can parse directly.
al_node_year_month = f.col("year_month").cast("string")
al_node_effective_date = f.when(
    al_node_year_month.rlike(r"^[0-9]{6}$"),
    f.to_timestamp(f.concat(al_node_year_month, f.lit("01")), "yyyyMMdd"),
).otherwise(f.to_timestamp(f.col("year_month")))

# Attribute expressions of an alert (AL) node.
al_node_ai = f.col("tr_id").cast("string")
al_node_ri = f.col("risk_id").cast("string")

# One row per selected activity; the node id is "<AI>_<RI>".
al_nodes_df = sf.select(
    f.concat_ws("_", al_node_ai, al_node_ri).alias("id"),
    al_node_effective_date.alias("effective_date"),
    al_node_ai.alias("AI"),
    f.col("is_suppressed").cast("boolean").alias("SP"),
    al_node_ri.alias("RI"),
)


In [20]:
# Publish the alert nodes built above as type "AL" in "demo_dwallets_graph".
publish_nodes(context, al_nodes_df, "demo_dwallets_graph", "AL")

2025-08-07 14:42:39,308:INFO:thetaray.common.logging:Truncating data by execution date: tr_job_ts = '1970-02-01 00:00:00' AND type IN ('AL')


{'node_count': 0}

### Extract and publish alert-to-account edges

In [21]:
from pyspark.sql import functions as f

sf = selected_activity_fields

# year_month is either a six-digit "yyyyMM" string (mapped to the first day of
# that month) or a value that to_timestamp can parse directly.
al_edge_year_month = f.col("year_month").cast("string")
al_edge_effective_date = f.when(
    al_edge_year_month.rlike(r"^[0-9]{6}$"),
    f.to_timestamp(f.concat(al_edge_year_month, f.lit("01")), "yyyyMMdd"),
).otherwise(f.to_timestamp(f.col("year_month")))

# Source is the alert node id ("<tr_id>_<risk_id>"); target is the client id.
al_edge_source = f.concat_ws("_", f.col("tr_id").cast("string"), f.col("risk_id"))
al_edge_target = f.col("client_id").cast("string")

# One edge per distinct "<source>_<target>" id.
al_edges_df = sf.select(
    f.concat_ws("_", al_edge_source, al_edge_target).alias("id"),
    al_edge_effective_date.alias("effective_date"),
    al_edge_source.alias("source_node"),
    al_edge_target.alias("target_node"),
).dropDuplicates(["id"])


In [23]:
# Publish the alert edges as type "AL", linking "AL" source nodes to "AC"
# target nodes in "demo_dwallets_graph".
publish_edges(
    context,
    al_edges_df,
    "demo_dwallets_graph",
    edge_type="AL",
    source_node_type="AL",
    target_node_type="AC")

2025-08-07 14:42:48,570:INFO:thetaray.common.logging:Truncating data by execution date: tr_job_ts = '1970-02-01 00:00:00' AND type IN ('AL')
                                                                                

{'edges_count': 0, 'unknown_nodes_count': 0}

In [24]:
# Release the cached activity frame now that the AL nodes and edges that
# depend on it have been published.
selected_activity_fields.unpersist()

DataFrame[tr_id: string, risk_id: string, year_month: timestamp, is_suppressed: boolean, customer_id: string]

In [25]:
# Close the context created by init_context in the first cell.
context.close()