### Init Context

In [2]:
from thetaray.api.context import init_context
from datetime import datetime
from datetime import timedelta
import yaml

import logging
# Emit every log record (DEBUG and above) using only the bare message text.
logging.basicConfig(level=logging.DEBUG, format='%(message)s')

# Read the 'spark_config_a' profile from the solution's Spark configuration file.
# NOTE(review): yaml.safe_load would be enough if the file only holds plain
# mappings/scalars - confirm before switching loaders.
with open('/thetaray/git/solutions/domains/demo_merchant/config/spark_config.yaml') as spark_config_file:
    spark_config = yaml.load(spark_config_file, Loader=yaml.FullLoader)['spark_config_a']

# NOTE(review): the original author tagged spark_conf (and a disabled
# spark_master='local[*]' override) with "remove" - presumably temporary
# settings; confirm before this notebook is promoted.
context = init_context(execution_date=datetime(1970, 2, 1), spark_conf=spark_config)

2025-08-28 13:02:53,127:INFO:thetaray.common.logging:start loading solution.....[ load_risks=True , solution_path=/thetaray/git/solutions/domains , settings_path=/thetaray/git/solutions/settings ]
2025-08-28 13:02:53,570:INFO:thetaray.common.logging:load_risks took: 0.12935304641723633
2025-08-28 13:02:54,279:INFO:thetaray.common.logging:=== Started updating schema ===


### Imports

In [3]:
from thetaray.api.dataset import dataset_functions

from domains.demo_merchant.datasets.customers import customers_dataset
from domains.demo_merchant.datasets.transactions import transactions_dataset
from domains.demo_merchant.datasets.customer_insights import customer_insights_dataset
from domains.demo_merchant.evaluation_flows.ef import evaluation_flow

import json
import psycopg2
import os
import random
import pandas as pd
import numpy as np

from datetime import datetime
from faker import Faker
from pyspark.sql import functions as f
from pyspark.sql.types import StructType

from thetaray.api.dataset.schema import DatasetSchemaHandler
from thetaray.common import Constants, Settings
from thetaray.common.data_environment import DataEnvironment

# Shared namespace name with its 'shared-' prefix dropped; the alert helpers
# use it to derive the apps_* schema name.
ns_suffix = Settings.SHARED_NAMESPACE.removeprefix('shared-')

# Spark session owned by the execution context created earlier.
spark = context.get_spark_session()

# Faker instance for synthetic values.
fake = Faker()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/28 13:02:59 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
Hive Session ID = 33275bc2-bc34-4c68-af9b-72263008c63f
25/08/28 13:03:03 INFO SessionState: Hive Session ID = 33275bc2-bc34-4c68-af9b-72263008c63f
                                                                                

### Customer insights dataset creation

In [4]:
DB_HOST = Settings.DB_HOST

# CDD credentials come from the environment.
DB_USER_CDD = os.environ['CDD_POSTGRES_USERNAME']
DB_PASS_CDD = os.environ['CDD_POSTGRES_PASSWORD']
# NOTE(review): the RP credentials are hard-coded; consider sourcing them from
# the environment like the CDD ones.
DB_USER_RP = 'postgres'
DB_PASS_RP = 'postgres'

# Split "<host>:<port>" at the last colon so ports of any length work.
# NOTE(review): assumes DB_HOST uses ':' as the separator - confirm. When no
# colon is present the historical fixed-width split (4-character port) is kept.
_DB_HOSTNAME, _db_sep, _DB_PORT = DB_HOST.rpartition(':')
if not _db_sep:
    _DB_HOSTNAME, _DB_PORT = DB_HOST[:-5], DB_HOST[-4:]


def _build_dsn(user, password):
    """Return the connection string for ``user``/``password`` on the CDD database.

    Both DSNs in this notebook share the database name, host, port and TLS
    settings; only the credentials differ.
    """
    return (
        f'user={user} '
        f'password={password} '
        f'dbname={Constants.CDD_DB_NAME} '
        f'host={_DB_HOSTNAME} '
        f'port={_DB_PORT} '
        'sslmode=verify-ca '
        'sslrootcert=/certs/ca.crt'
    )


dsn_cdd = _build_dsn(DB_USER_CDD, DB_PASS_CDD)

dsn_rp = _build_dsn(DB_USER_RP, DB_PASS_RP)


def execute_query(query, dsn, params=None):
    """Run ``query`` on a fresh connection and return the rows as dictionaries.

    Args:
        query: SQL text to execute. It must produce a result set, because the
            column names are read from ``cursor.description``.
        dsn: Connection string handed to ``psycopg2.connect``.
        params: Optional values bound to the placeholders in ``query``. Prefer
            this over formatting values into the SQL text.

    Returns:
        A list with one ``{column_name: value}`` dict per fetched row.

    The connection is always closed before returning, including when the query
    raises, so repeated calls do not leak server connections.
    """
    conn = psycopg2.connect(dsn=dsn)
    try:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            columns = [col.name for col in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]
    finally:
        conn.close()

def get_alert_mapper(solution, ef_id):
    """Find the rp_mappers row bound to a solution / evaluation-flow pair.

    Each row's ``solution_evaluation_flow_unit`` column is decoded as JSON and
    only its first element is inspected. Rows whose decoded value is empty are
    ignored.

    Returns:
        The first matching row (a dict), or ``None`` when nothing matches.
    """
    apps_schema = f'apps_{ns_suffix.replace("-", "_")}'
    mapper_rows = execute_query(f'SELECT * FROM {apps_schema}.rp_mappers', dsn_rp)
    for mapper_row in mapper_rows:
        flow_units = json.loads(mapper_row['solution_evaluation_flow_unit'])
        if flow_units:
            first_unit = flow_units[0]
            if first_unit['solutionId'] == solution and first_unit['evaluationFlowId'] == ef_id:
                return mapper_row
    return None


def get_alerts(solution, ef_id):
    """Return the CURRENT alerts of an evaluation flow, each tagged with its merchant id.

    The alert mapper for ``solution`` / ``ef_id`` is looked up first; its
    ``identifier`` selects the rows of ``rp_alerts``. Every returned alert dict
    gains a ``merchant_id`` key copied from the ``rp_alert_fields`` row whose
    ``rp_alert_id`` equals the alert's ``alert_id``.

    Raises:
        Exception: when no alert mapper exists for the given pair.
        KeyError: when an alert has no corresponding ``rp_alert_fields`` row.
    """
    apps_schema = f'apps_{ns_suffix.replace("-", "_")}'
    mapper_row = get_alert_mapper(solution, ef_id)
    if mapper_row is None:
        raise Exception(f'Alert mapper not found for {solution = } and {ef_id = }')
    alert_mapper_identifier = mapper_row['identifier']

    # Index the alert-field rows by the alert they belong to.
    fields_by_alert_id = {}
    for field_row in execute_query(f'SELECT * FROM {apps_schema}.rp_alert_fields', dsn_rp):
        fields_by_alert_id[field_row['rp_alert_id']] = field_row

    current_alerts = execute_query(f"SELECT * FROM {apps_schema}.rp_alerts WHERE alert_mapper_identifier = '{alert_mapper_identifier}' AND history_type = 'CURRENT'", dsn_rp)
    for current_alert in current_alerts:
        matching_fields = fields_by_alert_id[current_alert['alert_id']]
        current_alert['merchant_id'] = matching_fields['merchant_id']
    return current_alerts


def get_accounts(solution):
    """Return every row of the solution's ``demo_merchant_customers`` table as dicts."""
    solution_schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)
    return execute_query(f"SELECT * FROM {solution_schema}.demo_merchant_customers", dsn_cdd)


def get_account_records(solution, merchant_id):
    """Return the ``demo_merchant_customers`` rows of one merchant as dicts.

    NOTE(review): ``merchant_id`` is interpolated into the SQL text as-is, so
    callers must only pass trusted values.
    """
    solution_schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)
    return execute_query(
        f"SELECT * FROM {solution_schema}.demo_merchant_customers WHERE merchant_id = '{merchant_id}'",
        dsn_cdd,
    )


def get_account_transactions(solution, merchant_id):
    """Return the ``demo_merchant_transactions`` rows of one merchant as dicts.

    NOTE(review): ``merchant_id`` is interpolated into the SQL text as-is, so
    callers must only pass trusted values.
    """
    solution_schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)
    return execute_query(
        f"SELECT * FROM {solution_schema}.demo_merchant_transactions WHERE merchant_id = '{merchant_id}'",
        dsn_cdd,
    )

In [16]:
def build_customer_insights_df(context, spark, dsn_cdd):
    """Assemble one customer-insights row per merchant and hand the rows to Spark.

    Reads the ``demo_merchant_customers`` (KYC), ``demo_merchant_transactions``
    and, best effort, ``demo_merchant_customer_monthly`` tables of the current
    solution schema through ``execute_query``. Transactions are aggregated per
    merchant and calendar month ("YYYY-MM"). The most recent month with
    activity supplies the current values and the earlier active months supply
    the baselines.

    Args:
        context: Execution context. ``context.execution_date`` is written to
            ``effective_date`` and the context is passed to
            ``DatasetSchemaHandler``.
        spark: Object exposing ``createDataFrame(rows, schema=...)``.
        dsn_cdd: Connection string passed to ``execute_query`` for every read.

    Returns:
        Whatever ``spark.createDataFrame`` returns for the rows, restricted to
        the fields declared by ``customer_insights_dataset()``. Merchants
        without any dated transaction produce no row.

    Notes:
        * A monthly-table column overrides the computed
          ``low_value_trx_ratio``, ``refund_count_ratio`` and
          ``rapid_load_transfer`` only when it holds a non-NULL value; a NULL
          falls back to the value computed from the transactions.
        * ``is_dormant_account`` is 1.0 when more than ``DORMANT_LOOKBACK``
          calendar months separate the latest active month from the previous
          active one, i.e. at least that many months passed with no
          transactions. Months without transactions have no aggregate entry,
          so dormancy is derived from that gap.
        * The ``tr_*_seg`` values are synthetic, drawn from a random generator
          seeded with the merchant id so they are reproducible.
    """
    import json
    import logging
    import math
    import random
    from collections import defaultdict
    from datetime import datetime, timedelta

    import numpy as np
    from pyspark.sql.types import StructType
    from thetaray.api.dataset.schema import DatasetSchemaHandler
    from thetaray.common import Constants, Settings
    from thetaray.common.data_environment import DataEnvironment

    logger = logging.getLogger(__name__)

    solution = Settings.SOLUTION
    schema = Constants.SOLUTION_SCHEMA_TPL.format(solution=solution)

    LOW_VALUE_THRESHOLD = 5.0  # amounts <= this count as "low value"
    SPIKE_WINDOW = 6           # trailing active months used as the count baseline
    SPIKE_THRESH = 1.8         # current count must exceed baseline mean * this
    DORMANT_LOOKBACK = 3       # silent calendar months required before a reactivation
    AVG_TKT_TRAIL = 6          # trailing active months used for the ticket median
    MONTHLY_TS_KEYS = ("year_month", "period", "ds", "effective_date", "snapshot_ts", "created_at", "created_ts")

    def _as_int(x):
        """Best-effort int conversion; None, falsy or unparseable values become 0."""
        try:
            return int(x or 0)
        except Exception:
            return 0

    def _as_float(x):
        """Best-effort float conversion; None, blank strings or unparseable values become 0.0."""
        try:
            if x is None or (isinstance(x, str) and not x.strip()):
                return 0.0
            return float(x)
        except Exception:
            return 0.0

    def _ym_from_any(v):
        """Return the "YYYY-MM" bucket of a datetime or ISO-like string, else None."""
        if v is None:
            return None
        if isinstance(v, datetime):
            return f"{v.year:04d}-{v.month:02d}"
        s = str(v)
        if len(s) >= 7 and s[4] == "-":
            return s[:7]
        return None

    def _parse_any_ts(v):
        """Parse a datetime, "YYYY-MM" or ISO-like date/time string; None when unparseable."""
        if v is None:
            return None
        if isinstance(v, datetime):
            return v
        s = str(v).strip()
        if len(s) == 7 and s[4] == "-":
            try:
                y, m = map(int, s.split("-"))
                return datetime(y, m, 1)
            except Exception:
                return None
        for fmt in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S.%f"):
            try:
                return datetime.strptime(s, fmt)
            except Exception:
                pass
        try:
            return datetime.fromisoformat(s[:19])
        except Exception:
            return None

    def _month_index(ym):
        """Map "YYYY-MM" to a running month number (year * 12 + month); None when malformed."""
        try:
            year, month = str(ym).split("-")[:2]
            return int(year) * 12 + int(month)
        except (TypeError, ValueError):
            return None

    def _monthly_sort_key(r):
        """Timestamp used to order monthly rows; rows without a usable one sort first."""
        for k in MONTHLY_TS_KEYS:
            if k in r and r.get(k) is not None:
                ts = _parse_any_ts(r.get(k))
                if ts:
                    return ts
        return datetime.min

    def _low_value_ratio(agg):
        """Share of the month's transactions at or below LOW_VALUE_THRESHOLD."""
        denom = float(agg["txn_count"] or 0)
        return float(agg["low_value_cnt"]) / denom if denom > 0.0 else 0.0

    def _refund_ratio(agg):
        """Share of the month's transactions flagged as refunds."""
        denom = float(agg["txn_count"] or 0)
        return float(agg["refunds"]) / denom if denom > 0.0 else 0.0

    def _rapid_flag(datetimes, amounts):
        """1.0 when some day has >= 12 small transactions plus a large one soon after.

        "Small" is <= 4.99, "large" is >= 4.99 * 50, and the large transaction
        must fall within 240 minutes after the day's first small transaction.
        """
        if not datetimes or not amounts:
            return 0.0
        SMALL = 4.99
        LARGE_MULT = 50.0
        WINDOW_MIN = 240
        by_day = defaultdict(list)
        for ts, amt in sorted(zip(datetimes, amounts), key=lambda x: x[0]):
            by_day[ts.date()].append((ts, amt))
        for day_rows in by_day.values():
            small_ts = [ts for ts, amt in day_rows if amt <= SMALL]
            if len(small_ts) < 12:
                continue
            first_small = small_ts[0]
            window_end = first_small + timedelta(minutes=WINDOW_MIN)
            if any(first_small <= ts <= window_end and amt >= SMALL * LARGE_MULT for ts, amt in day_rows):
                return 1.0
        return 0.0

    def _minmax_ts(all_aggs):
        """Earliest and latest parsed transaction timestamp across the given months."""
        ts_all = []
        for _, ag in all_aggs:
            ts_all.extend(ag.get("datetimes", []))
        if not ts_all:
            return None, None
        return min(ts_all), max(ts_all)

    def _monthly_override(monthly_row, key, compute):
        """Use the monthly-table value when it is present and non-NULL, else compute it."""
        value = monthly_row.get(key)
        return float(value) if value is not None else compute()

    # ---- Load the source tables ------------------------------------------------
    kyc_rows = execute_query(f"SELECT * FROM {schema}.demo_merchant_customers", dsn_cdd) or []
    tx_rows = execute_query(f"SELECT * FROM {schema}.demo_merchant_transactions", dsn_cdd) or []
    try:
        monthly_rows = execute_query(f"SELECT * FROM {schema}.demo_merchant_customer_monthly", dsn_cdd) or []
    except Exception as exc:
        # The monthly table is optional: keep going, but leave a trace of why
        # the overrides are not applied.
        logger.warning("demo_merchant_customer_monthly could not be read; using computed ratios only: %s", exc)
        monthly_rows = []

    # ---- KYC lookup by merchant id (a later row replaces an earlier one) ---------
    by_merchant_kyc = {}
    for r in kyc_rows:
        mid = str(r.get("merchant_id") or "").strip()
        if mid:
            by_merchant_kyc[mid] = r

    # ---- Aggregate transactions per (merchant, month) --------------------------
    by_mid_month = defaultdict(lambda: {
        "observed_sales": 0.0,
        "txn_count": 0,
        "refunds": 0,
        "low_value_cnt": 0,
        "amounts": [],
        "datetimes": []
    })

    for t in tx_rows:
        mid = str(t.get("merchant_id") or "").strip()
        if not mid:
            continue
        ts = (t.get("transaction_datetime") or t.get("transaction_ts") or t.get("created_at") or t.get("timestamp"))
        ym = _ym_from_any(ts)
        if ym is None:
            continue
        amt = _as_float(t.get("amount"))
        agg = by_mid_month[(mid, ym)]
        agg["observed_sales"] += amt
        agg["txn_count"] += 1
        agg["refunds"] += (1 if bool(t.get("is_refund")) else 0)
        agg["low_value_cnt"] += (1 if amt <= LOW_VALUE_THRESHOLD else 0)
        agg["amounts"].append(amt)
        ts_dt = _parse_any_ts(ts)
        if ts_dt:
            agg["datetimes"].append(ts_dt)

    for agg in by_mid_month.values():
        agg["avg_ticket"] = (agg["observed_sales"] / agg["txn_count"]) if agg["txn_count"] > 0 else 0.0

    # ---- Latest monthly row per merchant ---------------------------------------
    monthly_by_mid = {}
    for r in monthly_rows:
        mid = str(r.get("merchant_id") or r.get("client_id") or "").strip()
        if mid:
            monthly_by_mid.setdefault(mid, []).append(r)
    latest_monthly_by_mid = {
        mid: sorted(rows, key=_monthly_sort_key)[-1]
        for mid, rows in monthly_by_mid.items()
    }

    # ---- Chronological month timeline per merchant -----------------------------
    timeline = defaultdict(list)
    for (mid, ym), agg in by_mid_month.items():
        timeline[mid].append((ym, agg))
    for mid in timeline:
        timeline[mid].sort(key=lambda p: p[0])

    # ---- One output row per merchant -------------------------------------------
    rows_for_df = []

    for mid, entries in timeline.items():
        if not entries:
            continue

        kyc = by_merchant_kyc.get(mid, {})
        legal_name = (kyc.get("legal_name") or "").strip()
        business_name = (kyc.get("business_name") or "").strip()
        kyc_name_val = legal_name or business_name or "Unknown"
        risk_score_val = _as_float(kyc.get("risk_score"))
        mcc = kyc.get("mcc")
        mcc_desc = kyc.get("mcc_description")
        state = kyc.get("state")

        # The last entry is the current month; everything before it is history.
        last_ym, last_agg = entries[-1]
        history = entries[:-1]

        trx_from_dt, trx_to_dt = _minmax_ts(entries)

        prev_counts = [float(ag["txn_count"]) for _, ag in history]
        prev_avg_t = [float(ag["avg_ticket"]) for _, ag in history]

        base_counts = prev_counts[-SPIKE_WINDOW:] if prev_counts else []
        mean_base = (sum(base_counts) / len(base_counts)) if base_counts else 0.0

        cur_txn_count = float(last_agg["txn_count"])
        spike_of_trx = 1.0 if (mean_base > 1.0 and cur_txn_count > mean_base * SPIKE_THRESH) else 0.0

        # Dormancy: the gap between the current and the previous active month
        # must contain at least DORMANT_LOOKBACK months without transactions.
        is_dormant_account = 0.0
        if history and cur_txn_count > 0.0:
            prev_idx = _month_index(history[-1][0])
            last_idx = _month_index(last_ym)
            if prev_idx is not None and last_idx is not None and (last_idx - prev_idx) > DORMANT_LOOKBACK:
                is_dormant_account = 1.0

        base_avg = prev_avg_t[-AVG_TKT_TRAIL:] if prev_avg_t else []
        med_base = float(np.median(base_avg)) if base_avg else 0.0
        avg_txn_amt_ratio = (float(last_agg["avg_ticket"]) / med_base) if med_base > 0.0 else 1.0

        observed_sales = float(last_agg["observed_sales"])
        mrow = latest_monthly_by_mid.get(mid, {})

        low_value_trx_ratio = _monthly_override(mrow, "low_value_trx_ratio", lambda: _low_value_ratio(last_agg))
        refund_count_ratio = _monthly_override(mrow, "refund_count_ratio", lambda: _refund_ratio(last_agg))
        rapid_load_transfer = _monthly_override(
            mrow,
            "rapid_load_transfer",
            lambda: _rapid_flag(last_agg.get("datetimes", []), last_agg.get("amounts", [])),
        )

        declared = _as_float(kyc.get("monthly_volume_declared"))
        revenue_mismatch = abs(declared - observed_sales) / observed_sales if observed_sales > 0.0 else 0.0

        # Synthetic segment figures; seeding with the merchant id keeps them
        # stable between runs. The two draws must stay in this order.
        rnd = random.Random(str(mid))
        tr_in_seg_value = float(observed_sales * rnd.uniform(0.4, 0.8))
        tr_out_seg_value = 0.0
        tr_in_seg_count_val = int(min(cur_txn_count, math.floor(cur_txn_count * rnd.uniform(0.4, 0.8))))
        tr_out_seg_count_val = 0

        cc_val = (kyc.get("customer_country") or "US")
        ad_line = (kyc.get("address_line") or "")
        director_ad = json.dumps({"CC": cc_val, "AD": ad_line, "CL": "Medium"}, ensure_ascii=False)
        company_ad = json.dumps({"CC": cc_val, "AD": ad_line, "CL": "Medium"}, ensure_ascii=False)

        row = {
            "merchant_id": mid,
            "legal_name": legal_name,
            "kyc_name": kyc_name_val,
            "kyc_classification": "Medium",
            "kyc_is_new": False,
            "kyc_recently_updated": True,
            "kyc_newly_incorporation": False,
            "kyc_new_customer": False,
            "kyc_occupation": "",
            "hr_cc": [],
            "mr_cc": [],
            "lr_cc": [],
            "director_ad": director_ad,
            "company_ad": company_ad,
            "tr_in": float(observed_sales),
            "tr_out": float(0.0),
            "tr_in_count": _as_int(cur_txn_count),
            "tr_out_count": 0,
            "tr_in_seg": float(tr_in_seg_value),
            "tr_out_seg": float(tr_out_seg_value),
            "tr_in_seg_count": _as_int(tr_in_seg_count_val),
            "tr_out_seg_count": _as_int(tr_out_seg_count_val),
            "trx_from_date": trx_from_dt,
            "trx_to_date": trx_to_dt,
            "tm": json.dumps({"Open": 7, "Closed": 0, "False_positives": 0}, ensure_ascii=False),
            "scrn": json.dumps({"Open": 0, "Closed": 0, "False_positives": 0}, ensure_ascii=False),
            "customer_country": cc_val,
            "wallet_age_days": _as_int(0),
            "uses_crypto": False,
            "has_high_risk_country_tx": False,
            "kyc_risk_level": "Medium",
            "effective_date": context.execution_date,
            "low_value_trx_ratio": float(low_value_trx_ratio),
            "is_dormant_account": float(is_dormant_account),
            "spike_of_trx": float(spike_of_trx),
            "refund_count_ratio": float(refund_count_ratio),
            "revenue_mismatch": float(revenue_mismatch),
            "avg_txn_amt_ratio": float(avg_txn_amt_ratio),
            "rapid_load_transfer": float(rapid_load_transfer),
            "mcc": mcc,
            "mcc_description": mcc_desc,
            "state": state,
            "risk_score": float(risk_score_val),
            "monthly_volume_declared": float(declared),
            "average_ticket_declared": float(_as_float(kyc.get("average_ticket_declared"))),
        }
        rows_for_df.append(row)

    # ---- Project onto the dataset's declared fields and build the DataFrame ----
    ds = customer_insights_dataset()
    ds_schema = DatasetSchemaHandler(ds, context, data_environment=DataEnvironment.get_default())._build_dataset_schema()
    field_names = [field.identifier for field in ds.field_list]
    # Fields the dataset declares but this function does not compute become None.
    filtered_rows = [{k: r.get(k, None) for k in field_names} for r in rows_for_df]
    ds_schema = StructType([s for s in ds_schema if s.name in field_names])
    customer_insights_df = spark.createDataFrame(filtered_rows, schema=ds_schema)
    return customer_insights_df


In [17]:
# Build the insights DataFrame from the CDD tables using the session objects created above.
customer_insights_df = build_customer_insights_df(context=context, spark=spark, dsn_cdd=dsn_cdd)



In [None]:
# Write the DataFrame to the customer-insights dataset in the PUBLIC data
# environment, then publish that dataset in the same environment.
dataset_functions.write(
    context,
    customer_insights_df,
    customer_insights_dataset().identifier,
    data_environment=DataEnvironment.PUBLIC,
)
dataset_functions.publish(
    context,
    customer_insights_dataset().identifier,
    data_environment=DataEnvironment.PUBLIC,
)

In [None]:
# Close the execution context created by init_context at the top of the notebook.
context.close()