### Init Context

In [2]:
from thetaray.api.context import init_context
from datetime import datetime
from datetime import timedelta
import yaml

import logging
logging.basicConfig(level=logging.DEBUG, format='%(message)s')

with open('/thetaray/git/solutions/domains/demo_digital_wallets/config/spark_config.yaml') as spark_config_file:
    spark_config = yaml.load(spark_config_file, yaml.FullLoader)['spark_config_a']
    
context = init_context(
    execution_date=datetime(1970, 2, 1),
    #spark_conf=spark_config,
    spark_conf=spark_config, # quitar
    # spark_master='local[*]', # quitar
)

2025-08-28 13:03:29,841:INFO:thetaray.common.logging:start loading solution.....[ load_risks=True , solution_path=/thetaray/git/solutions/domains , settings_path=/thetaray/git/solutions/settings ]
2025-08-28 13:03:30,282:INFO:thetaray.common.logging:load_risks took: 0.1303706169128418
2025-08-28 13:03:31,000:INFO:thetaray.common.logging:=== Started updating schema ===


### Imports

In [3]:
from thetaray.api.dataset import dataset_functions

from domains.demo_digital_wallets.datasets.customers import customers_dataset
from domains.demo_digital_wallets.datasets.transactions import transactions_dataset
from domains.demo_digital_wallets.datasets.customer_insights import customer_insights_dataset
from domains.demo_digital_wallets.evaluation_flows.ef import evaluation_flow

import json
import psycopg2
import os
import random
import pandas as pd

from datetime import datetime
from faker import Faker
from pyspark.sql import functions as f
from pyspark.sql.types import StructType

from thetaray.api.dataset.schema import DatasetSchemaHandler
from thetaray.common import Constants, Settings
from thetaray.common.data_environment import DataEnvironment

spark = context.get_spark_session()

ns_suffix = Settings.SHARED_NAMESPACE.removeprefix('shared-')

fake = Faker()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/28 13:03:35 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/08/28 13:03:36 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
Hive Session ID = 4a8aa42f-b3e6-4276-80b7-8a38070b6115
25/08/28 13:03:40 INFO SessionState: Hive Session ID = 4a8aa42f-b3e6-4276-80b7-8a38070b6115


### Creation

In [4]:
DB_HOST = Settings.DB_HOST

DB_USER_CDD = os.environ['CDD_POSTGRES_USERNAME']
DB_PASS_CDD = os.environ['CDD_POSTGRES_PASSWORD']
DB_USER_RP = 'postgres'
DB_PASS_RP = 'postgres'


dsn_cdd = (
    f'user={DB_USER_CDD} '
    f'password={DB_PASS_CDD} '
    f'dbname={Constants.CDD_DB_NAME} '
    f'host={DB_HOST[:-5]} '
    f'port={DB_HOST[-4:]} '
    'sslmode=verify-ca '
    'sslrootcert=/certs/ca.crt'
)


dsn_rp = (
    f'user={DB_USER_RP} '
    f'password={DB_PASS_RP} '
    f'dbname={Constants.CDD_DB_NAME} '
    f'host={DB_HOST[:-5]} '
    f'port={DB_HOST[-4:]} '
    'sslmode=verify-ca '
    'sslrootcert=/certs/ca.crt'
)


def execute_query(query, dsn):
    conn = psycopg2.connect(dsn=dsn)
    with conn.cursor() as cursor:
        cursor.execute(query)
        columns = [col.name for col in cursor.description]
        rows = []
        for row in cursor.fetchall():
            rows.append({col: val for col, val in zip(columns, row)})
        return rows

def get_alert_mapper(solution, ef_id):
    schema = f'apps_{ns_suffix.replace("-", "_")}'
    for alert_mapper in execute_query(f'SELECT * FROM {schema}.rp_mappers', dsn_rp):
        ef_unit = json.loads(alert_mapper['solution_evaluation_flow_unit'])
        if not ef_unit:
            continue
        ef_unit = ef_unit[0]
        if ef_unit['solutionId'] == solution and ef_unit['evaluationFlowId'] == ef_id:
            return alert_mapper


def get_alerts(solution, ef_id):
    schema = f'apps_{ns_suffix.replace("-", "_")}'
    alert_mapper = get_alert_mapper(solution, ef_id)
    if alert_mapper is None:
        raise Exception(f'Alert mapper not found for {solution = } and {ef_id = }')
    alert_mapper_identifier = alert_mapper['identifier']
    alert_fields = execute_query(f'SELECT * FROM {schema}.rp_alert_fields', dsn_rp)
    alert_fields = {alert_field['rp_alert_id']: alert_field for alert_field in alert_fields}
    alerts = execute_query(f"SELECT * FROM {schema}.rp_alerts WHERE alert_mapper_identifier = '{alert_mapper_identifier}' AND history_type = 'CURRENT'", dsn_rp)
    for alert in alerts:
        alert['client_id'] = alert_fields[alert['alert_id']]['client_id']
    return alerts


def get_accounts(solution):
    schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)
    query = f"SELECT * FROM {schema}.demo_digital_wallets_customers"
    return execute_query(query, dsn_cdd)


def get_account_records(solution, client_id):
    schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)
    query = f"SELECT * FROM {schema}.demo_digital_wallets_customers WHERE client_id = '{client_id}'"
    return execute_query(query, dsn_cdd)


def get_account_transactions(solution, client_id):
    schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)
    query = f"SELECT * FROM {schema}.demo_digital_wallets_transactions WHERE client_id = '{client_id}'"
    return execute_query(query, dsn_cdd)

In [5]:
import json
import math
import random
from datetime import datetime
from typing import List, Dict, Any, Optional

def build_customer_insights_df(context, spark, dsn_cdd):
    
    solution = Settings.SOLUTION
    ef_id = evaluation_flow().identifier
    schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)

    alerts: List[Dict[str, Any]] = get_alerts(solution, ef_id)
    accounts: List[Dict[str, Any]] = get_accounts(solution)

    # Mapa client_id -> country_of_residence_code (y demás campos KYC base)
    customers_rows = execute_query(f"SELECT * FROM {schema}.demo_digital_wallets_customers", dsn_cdd)
    by_customer_row: Dict[str, Dict[str, Any]] = {r["client_id"]: r for r in customers_rows}
    account_country: Dict[str, Optional[str]] = {
        cid: row.get("country_of_residence_code") for cid, row in by_customer_row.items()
    }

    # (Opcional) Dataset agregado mensual con señales: si no existe, se ignora sin error.
    monthly_rows: List[Dict[str, Any]] = []
    try:
        monthly_rows = execute_query(
            f"SELECT * FROM {schema}.demo_digital_wallets_customer_monthly", dsn_cdd
        ) or []
    except Exception:
        monthly_rows = []
    # Nos quedamos con la última fila por cliente (mes más reciente) si hay historial
    latest_monthly_by_customer: Dict[str, Dict[str, Any]] = {}
    if monthly_rows:
        from datetime import datetime, date
    
        def _parse_any_ts(v) -> Optional[datetime]:
            """Intenta convertir v a datetime:
            - str ISO 'YYYY-MM' / 'YYYY-MM-DD' / 'YYYY-MM-DDTHH:MM:SS'
            - datetime o date
            - epoch (int/float) en segundos o milisegundos
            """
            if v is None:
                return None
            # datetime / date
            if isinstance(v, datetime):
                return v
            if isinstance(v, date):
                return datetime(v.year, v.month, v.day)
    
            # numérico: epoch
            if isinstance(v, (int, float)) and not isinstance(v, bool):
                # heurística: si es muy grande, probablemente ms
                try:
                    if v > 10_000_000_000:  # ~ 2001-09-09 en ms
                        return datetime.utcfromtimestamp(v / 1000.0)
                    return datetime.utcfromtimestamp(v)
                except Exception:
                    return None
    
            # string ISO
            if isinstance(v, str):
                s = v.strip()
                # 'YYYY-MM' -> asumir día 1
                if len(s) == 7 and s[4] == "-":
                    try:
                        y, m = map(int, s.split("-"))
                        return datetime(y, m, 1)
                    except Exception:
                        return None
                # intentos ISO más generales
                for fmt in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S.%f"):
                    try:
                        return datetime.strptime(s, fmt)
                    except Exception:
                        pass
                # último recurso: fromisoformat (puede fallar en algunos formatos)
                try:
                    return datetime.fromisoformat(s)
                except Exception:
                    return None
    
            return None
    
        def _key_safe(r: Dict[str, Any]) -> datetime:
            """
            Devuelve SIEMPRE un datetime comparable:
            - Prioriza: period, ds, effective_date, snapshot_ts, created_at, created_ts
            - Si nada parsea, usa datetime.min (garantiza orden total)
            """
            for k in ("period", "ds", "effective_date", "snapshot_ts", "created_at", "created_ts"):
                if k in r:
                    ts = _parse_any_ts(r.get(k))
                    if ts is not None:
                        return ts
            return datetime.min  # nunca None
    
        tmp: Dict[str, List[Dict[str, Any]]] = {}
        for r in monthly_rows:
            cid = r.get("client_id")
            if cid is None:
                # si aparece una fila sin client_id, la ignoramos para no romper
                continue
            tmp.setdefault(cid, []).append(r)
    
        for cid, rows in tmp.items():
            rows_sorted = sorted(rows, key=_key_safe)
            latest_monthly_by_customer[cid] = rows_sorted[-1] if rows_sorted else {}

    # ─────────────────────────────────────────────────────────────────────────────
    # Helpers
    # ─────────────────────────────────────────────────────────────────────────────
    def latest_record(recs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Último registro por tr_job_ts (si faltara, usa ahora)."""
        if not recs:
            return {}
        return sorted(recs, key=lambda x: x.get("tr_job_ts") or datetime.utcnow())[-1]

    def safe_sum_amount(trxs: List[Dict[str, Any]]) -> float:
        return float(sum(float(t.get("amount") or 0.0) for t in trxs))

    def risk_country_buckets(trxs: List[Dict[str, Any]]):
        """Devuelve (hr_cc, mr_cc, lr_cc) como listas de strings sin duplicados."""
        hr, mr, lr = set(), set(), set()
        for t in trxs:
            risk = t.get("counterparty_country_risk")
            cc = t.get("counterparty_country")
            if not cc:
                continue
            if risk == "High":
                hr.add(cc)
            elif risk == "Medium":
                mr.add(cc)
            elif risk == "Low":
                lr.add(cc)
        return list(hr), list(mr), list(lr)

    def get_min_max_ts(trxs: List[Dict[str, Any]]):
        """Min y max de transaction_timestamp (None si no hay)."""
        if not trxs:
            return None, None
        ts_vals = [t.get("transaction_timestamp") for t in trxs if t.get("transaction_timestamp") is not None]
        if not ts_vals:
            return None, None
        return min(ts_vals), max(ts_vals)

    def as_int_or_zero(x: Optional[int]) -> int:
        try:
            return int(x or 0)
        except Exception:
            return 0

    def avg_ticket(trxs: List[Dict[str, Any]]) -> float:
        if not trxs:
            return 0.0
        return safe_sum_amount(trxs) / max(len(trxs), 1)

    def infer_kyc_risk_level(
        has_high_risk_country_tx: bool,
        alerts_open: int,
        structuring_score: Optional[float] = None,
        rapid_load_immediate_spend: Optional[float] = None,
        crypto_usage_score: Optional[float] = None,
        many_to_one_score: Optional[float] = None,
        activity_spike_score: Optional[float] = None,
    ) -> str:
        """
        Heurística simple y transparente:
          - High si: high-risk country o alertas abiertas elevadas (>=3) o structuring/crypto/many_to_one/spike > 0.8
          - Medium si: señales moderadas (>=0.5) o alertas abiertas 1-2
          - Low en otro caso
        """
        def high(sig):
            return sig is not None and float(sig) >= 0.8
        def mid(sig):
            return sig is not None and float(sig) >= 0.5

        if has_high_risk_country_tx or alerts_open >= 3 or any(high(s) for s in [
            structuring_score, crypto_usage_score, many_to_one_score, activity_spike_score, rapid_load_immediate_spend
        ]):
            return "High"
        if alerts_open >= 1 or any(mid(s) for s in [
            structuring_score, crypto_usage_score, many_to_one_score, activity_spike_score, rapid_load_immediate_spend
        ]):
            return "Medium"
        return "Low"

    # ─────────────────────────────────────────────────────────────────────────────
    # Construcción del payload fila a fila
    # ─────────────────────────────────────────────────────────────────────────────
    customer_insights_data: List[Dict[str, Any]] = []

    for acc in accounts:
        client_id = acc["client_id"]

        # Registros y transacciones
        account_records = get_account_records(solution, client_id)
        acc_latest = latest_record(account_records)
        account_transactions = get_account_transactions(solution, client_id)

        trxs_in  = [t for t in account_transactions
            if (t.get("direction") or t.get("in_out") or "").strip().lower() == "inflow"]
        trxs_out = [t for t in account_transactions
            if (t.get("direction") or t.get("in_out") or "").strip().lower() == "outflow"]
        acc_alerts = [a for a in alerts if a.get("client_id") == client_id]

        # ───── KYC
        # Tomamos base de customers si existe; si no, de record/account
        cust_row = by_customer_row.get(client_id, {})
        kyc_name = acc_latest.get("name") or acc.get("client_name") or cust_row.get("client_name") or ""
        kyc_occupation = acc_latest.get("occupation") or acc.get("occupation") or cust_row.get("occupation") or ""
        kyc_is_new = len(account_records) == 1
        kyc_recently_updated = not kyc_is_new
        # Para demos: incorporarción reciente pseudoaleatoria estable por client_id
        rnd = random.Random(str(client_id))
        kyc_newly_incorporation = rnd.choice([True, False, False])
        kyc_new_customer = kyc_is_new
        kyc_classification = "Medium"  # placeholder nominal para UI

        # ───── Países de riesgo (listas)
        hr_cc, mr_cc, lr_cc = risk_country_buckets(account_transactions)
        has_high_risk_country_tx = len(hr_cc) > 0

        # ───── Direcciones (STRING serializado como JSON)
        # Director
        director_ad_dict = {
            "CC": acc_latest.get("country_of_residence_code")
                  or acc.get("country_of_residence_code")
                  or cust_row.get("country_of_residence_code"),
            "AD": acc_latest.get("address") or acc.get("address") or cust_row.get("address"),
            "CL": "L",
        }
        director_ad = json.dumps(director_ad_dict, ensure_ascii=False)

        # Compañía: si no hay datos, replicamos del director
        company_ad_dict = {
            "CC": acc_latest.get("company_country_code") or director_ad_dict["CC"],
            "AD": acc_latest.get("company_address") or director_ad_dict["AD"],
            "CL": "L",
        }
        company_ad = json.dumps(company_ad_dict, ensure_ascii=False)

        # ───── Estadísticas de transacción
        tr_in = safe_sum_amount(trxs_in)
        tr_out = safe_sum_amount(trxs_out)
        tr_in_count = len(trxs_in)
        tr_out_count = len(trxs_out)

        # Segmentaciones “demo” (mismo tipo DOUBLE/LONG, acotadas a los conteos)
        tr_in_seg = float(tr_in * rnd.uniform(0.4, 0.8))
        tr_out_seg = float(tr_out * rnd.uniform(0.4, 0.8))
        tr_in_seg_count = min(tr_in_count, int(tr_in_count * rnd.uniform(0.4, 0.8)))
        tr_out_seg_count = min(tr_out_count, int(tr_out_count * rnd.uniform(0.4, 0.8)))

        trx_from_date, trx_to_date = get_min_max_ts(account_transactions)

        # ───── Alertas (STRING serializado)
        open_alerts = len([a for a in acc_alerts if a.get("state_id") != "state_closed"])
        closed_alerts = len([a for a in acc_alerts if a.get("state_id") == "state_closed"])
        fp_alerts = len([a for a in acc_alerts if a.get("resolution_code") == "Non_Issue"])
        tm = json.dumps({"Open": open_alerts, "Closed": closed_alerts, "False_positives": fp_alerts}, ensure_ascii=False)
        scrn = json.dumps({"Open": 0, "Closed": 0, "False_positives": 0}, ensure_ascii=False)

        # ───── Derivados
        customer_country = account_country.get(client_id)
        # Fecha creación: 'created_at' o 'first_seen_ts'; si None, 0
        created_ts = acc_latest.get("created_at") or acc_latest.get("first_seen_ts")
        if created_ts and isinstance(created_ts, datetime):
            wallet_age_days = (context.execution_date - created_ts).days
        else:
            wallet_age_days = 0

        uses_crypto = any((str(t.get("channel")).upper() == "CRYPTO") for t in account_transactions)

        # ───── Señales de riesgo desde el mensual (si existe)
        monthly = latest_monthly_by_customer.get(client_id, {}) if latest_monthly_by_customer else {}
        structuring_score = monthly.get("structuring_score")
        rapid_load_immediate_spend = monthly.get("rapid_load_immediate_spend")
        crypto_usage_score = monthly.get("crypto_usage_score")
        many_to_one_score = monthly.get("many_to_one_score")
        activity_spike_score = monthly.get("activity_spike_score")

        # Nivel KYC derivado (heurística)
        kyc_risk_level = infer_kyc_risk_level(
            has_high_risk_country_tx=has_high_risk_country_tx,
            alerts_open=open_alerts,
            structuring_score=structuring_score,
            rapid_load_immediate_spend=rapid_load_immediate_spend,
            crypto_usage_score=crypto_usage_score,
            many_to_one_score=many_to_one_score,
            activity_spike_score=activity_spike_score,
        )

        # ───── Ensamblado final (coincide 1:1 con el schema esperado)
        row = {
            "client_id": client_id,
            "customer_name": kyc_name,
            "kyc_classification": kyc_classification,
            "kyc_name": kyc_name,
            "kyc_is_new": bool(kyc_is_new),
            "kyc_recently_updated": bool(kyc_recently_updated),
            "kyc_newly_incorporation": bool(kyc_newly_incorporation),
            "kyc_new_customer": bool(kyc_new_customer),
            "kyc_occupation": kyc_occupation,
            "hr_cc": hr_cc,                         # LIST[str]
            "mr_cc": mr_cc,                         # LIST[str]
            "lr_cc": lr_cc,                         # LIST[str]
            "director_ad": director_ad,             # STRING (JSON)
            "company_ad": company_ad,               # STRING (JSON)
            "tr_in": float(tr_in),                  # DOUBLE (EUR)
            "tr_out": float(tr_out),                # DOUBLE (EUR)
            "tr_in_count": as_int_or_zero(tr_in_count),            # LONG
            "tr_out_count": as_int_or_zero(tr_out_count),          # LONG
            "tr_in_seg": float(tr_in_seg),          # DOUBLE (EUR)
            "tr_out_seg": float(tr_out_seg),        # DOUBLE (EUR)
            "tr_in_seg_count": as_int_or_zero(tr_in_seg_count),    # LONG
            "tr_out_seg_count": as_int_or_zero(tr_out_seg_count),  # LONG
            "trx_from_date": trx_from_date,         # TIMESTAMP | None
            "trx_to_date": trx_to_date,             # TIMESTAMP | None
            "tm": tm,
            "scrn": scrn,
            "customer_country": customer_country,   # STRING | None
            "wallet_age_days": as_int_or_zero(wallet_age_days),     # LONG
            "uses_crypto": bool(uses_crypto),       # BOOLEAN
            "has_high_risk_country_tx": bool(has_high_risk_country_tx),  # BOOLEAN
            "kyc_risk_level": kyc_risk_level,       # STRING
            "effective_date": context.execution_date,  # TIMESTAMP
        }

        customer_insights_data.append(row)

    # ─────────────────────────────────────────────────────────────────────────────
    # DataFrame con el schema exacto del DataSet
    # ─────────────────────────────────────────────────────────────────────────────
    customer_insights_ds = next(
        ds for ds in context.solution.datasets if ds.identifier == "demo_digital_wallets_customer_insights"
    )

    customer_insights_schema = DatasetSchemaHandler(
        customer_insights_ds, context, data_environment=DataEnvironment.get_default()
    )._build_dataset_schema()

    # Filtramos solo campos definidos en el DataSet (evita choque si hay extras)
    field_names = [f.identifier for f in customer_insights_ds.field_list]
    customer_insights_schema = StructType([s for s in customer_insights_schema if s.name in field_names])

    customer_insights_df = spark.createDataFrame(customer_insights_data, schema=customer_insights_schema)
    return customer_insights_df


In [6]:
customer_insights_df = build_customer_insights_df(
    context=context,
    spark=spark,
    dsn_cdd=dsn_cdd
)

In [11]:
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)  # muestra el texto completo

customer_insights_df.toPandas().head(3)

Unnamed: 0,client_id,client_name,kyc_classification,kyc_name,kyc_is_new,kyc_recently_updated,kyc_newly_incorporation,kyc_new_customer,kyc_occupation,hr_cc,mr_cc,lr_cc,director_ad,company_ad,tr_in,tr_out,tr_in_count,tr_out_count,tr_in_seg,tr_out_seg,tr_in_seg_count,tr_out_seg_count,trx_from_date,trx_to_date,tm,scrn,customer_country,wallet_age_days,uses_crypto,has_high_risk_country_tx,kyc_risk_level,effective_date
0,32164758029,,Medium,Jean Rossi,True,False,True,True,Analyst,[],[],[],"{""CC"": ""BE"", ""AD"": ""Belgium - 5429 Main St"", ""CL"": ""L""}","{""CC"": ""BE"", ""AD"": ""Belgium - 5429 Main St"", ""CL"": ""L""}",4253.74,3915.49,37,33,2205.357554,2098.709883,22,15,NaT,NaT,"{""Open"": 0, ""Closed"": 0, ""False_positives"": 0}","{""Open"": 0, ""Closed"": 0, ""False_positives"": 0}",BE,0,False,False,Low,1970-02-01
1,35340362714,,Medium,Rafael Romano,True,False,False,True,Self-employed,[],[],[],"{""CC"": ""IE"", ""AD"": ""Ireland - 1054 Main St"", ""CL"": ""L""}","{""CC"": ""IE"", ""AD"": ""Ireland - 1054 Main St"", ""CL"": ""L""}",5297.64,7398.57,30,40,3820.330345,3131.654883,16,30,NaT,NaT,"{""Open"": 0, ""Closed"": 0, ""False_positives"": 0}","{""Open"": 0, ""Closed"": 0, ""False_positives"": 0}",IE,0,False,False,Low,1970-02-01
2,3585942724218,,Medium,Luisa Garcia,True,False,True,True,Consultant,[],[],[],"{""CC"": ""FI"", ""AD"": ""Finland - 4465 Main St"", ""CL"": ""L""}","{""CC"": ""FI"", ""AD"": ""Finland - 4465 Main St"", ""CL"": ""L""}",4352.8,4155.47,51,43,2135.633841,3177.627996,32,26,NaT,NaT,"{""Open"": 0, ""Closed"": 0, ""False_positives"": 0}","{""Open"": 0, ""Closed"": 0, ""False_positives"": 0}",FI,0,False,False,Low,1970-02-01


In [None]:
dataset_functions.write(context, 
                        customer_insights_df, 
                        customer_insights_dataset().identifier,
                        data_environment=DataEnvironment.PUBLIC
)
dataset_functions.publish(context, 
                          customer_insights_dataset().identifier,
                          data_environment=DataEnvironment.PUBLIC
)

In [None]:
context.close()