### Init Context

In [2]:
import logging
import uuid
import random
import numpy as np
import pandas as pd

import yaml
from datetime import datetime
from faker import Faker
from thetaray.api.context import init_context
from pyspark.sql import functions as f
from thetaray.common.data_environment import DataEnvironment


# Logging configuration: root logger at DEBUG, printing only the message text.
logging.basicConfig(level=logging.DEBUG, format='%(message)s')

# pandas configuration: do not truncate the number of columns when displaying frames.
pd.set_option('display.max_columns', None)

# Load the Spark configuration; only the 'spark_config_a' section of the YAML file is used.
# NOTE(review): yaml.load with FullLoader is used on a repository-local file; if this path
# could ever point at untrusted content, yaml.safe_load would be the safer choice.
with open('/thetaray/git/solutions/domains/demo_merchant/config/spark_config.yaml') as spark_config_file:
    spark_config = yaml.load(spark_config_file, yaml.FullLoader)['spark_config_a']


# Initialise the platform context with a fixed execution date and the Spark settings loaded above.
# NOTE(review): the three boolean flags are passed straight to init_context; judging by their
# names they allow schema changes and remove undefined datasets / unused columns -- confirm
# against the platform documentation before running on a shared environment.
context = init_context(
    execution_date=datetime(1970, 2, 1),
    spark_conf=spark_config,
    # spark_master='local[*]', # drop
    allow_type_changes=True,
    drop_undefined_datasets=True,
    delete_unused_columns=True
)

2025-08-26 15:24:58,307:INFO:thetaray.common.logging:start loading solution.....[ load_risks=True , solution_path=/thetaray/git/solutions/domains , settings_path=/thetaray/git/solutions/settings ]
2025-08-26 15:24:58,853:ERROR:thetaray.common.logging:failed to load solution
Traceback (most recent call last):
  File "/thetaray/platform/python/thetaray/api/solution/loader.py", line 100, in _load_solution
    smd_dict['risk_manager'] = load_risk(smd_dict['datasets'], smd_dict['evaluation_flows'],smd_dict['enrichments_manager'], user_enrichments_manager, Settings.DOMAINS_PATH, disabled_domains)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/thetaray/platform/python/thetaray/api/solution/risks_manager.py", line 831, in load_risk
    risk_manager = RisksManager(
                   ^^^^^^^^^^^^^
  File "/thetaray/platform/python/thetaray/api

ValueError: Traceback (most recent call last):
  File "/thetaray/platform/python/thetaray/api/solution/loader.py", line 100, in _load_solution
    smd_dict['risk_manager'] = load_risk(smd_dict['datasets'], smd_dict['evaluation_flows'],smd_dict['enrichments_manager'], user_enrichments_manager, Settings.DOMAINS_PATH, disabled_domains)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/thetaray/platform/python/thetaray/api/solution/risks_manager.py", line 831, in load_risk
    risk_manager = RisksManager(
                   ^^^^^^^^^^^^^
  File "/thetaray/platform/python/thetaray/api/solution/risks_manager.py", line 77, in __init__
    self._load_risks()
  File "/thetaray/platform/python/thetaray/api/solution/risks_manager.py", line 534, in _load_risks
    validate_risks(
  File "/thetaray/platform/python/thetaray/api/solution/risk_validator.py", line 51, in validate_risks
    _validate_conditions(risk_manager, enrichments_manager, risks, datasets, evaluation_flows)
  File "/thetaray/platform/python/thetaray/api/solution/risk_validator.py", line 244, in _validate_conditions
    field_names = _get_field_names(
                  ^^^^^^^^^^^^^^^^^
  File "/thetaray/platform/python/thetaray/api/solution/risk_validator.py", line 177, in _get_field_names
    input_ds: DataSet = get_ds_metadata_by_identifier_from_list(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/thetaray/platform/python/thetaray/api/solution/metadata_functions.py", line 30, in get_ds_metadata_by_identifier_from_list
    raise ValueError("Metadata for dataset: " + identifier + " does not exist")
ValueError: Metadata for dataset: demo_nested_banking_customer_monthly does not exist


### Imports

In [None]:
from thetaray.api.dataset import dataset_functions

from domains.demo_merchant.datasets.transactions import transactions_dataset
from domains.demo_merchant.datasets.customer_monthly import customer_monthly_dataset
# from domains.demo_merchant.datasets.customer_insights import customer_insights_dataset 
from domains.demo_merchant.datasets.customers import customers_dataset 

# Data Gen

### 1. Transactions and Aggregate Features Generation

In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import random
from itertools import count
from datetime import datetime, timedelta
from typing import Optional, Tuple, List, Dict

# -------------------- US market basics --------------------
# Two-letter codes sampled for each merchant's "state": the 50 states plus DC and PR.
US_STATES = [
    "AL","AK","AZ","AR","CA","CO","CT","DE","FL","GA","HI","ID","IL","IN","IA","KS","KY","LA","ME","MD",
    "MA","MI","MN","MS","MO","MT","NE","NV","NH","NJ","NM","NY","NC","ND","OH","OK","OR","PA","RI","SC",
    "SD","TN","TX","UT","VT","VA","WA","WV","WI","WY","DC","PR"
]

# (code, description) pairs; one pair is drawn per merchant and the code also selects
# the merchant's base average ticket in generate_ma_fake_transactions.
MCC_LIST = [
    ("5411","Grocery Stores, Supermarkets"),
    ("5812","Eating Places, Restaurants"),
    ("5814","Fast Food Restaurants"),
    ("4111","Local Passenger Transportation"),
    ("5732","Electronics Stores"),
    ("5691","Clothing Stores"),
    ("5942","Book Stores"),
    ("4814","Telecommunication Services"),
    ("5977","Cosmetic Stores"),
    ("5999","Specialty Retail Stores"),
]

# Values sampled uniformly for the per-transaction "card_brand" and "channel" columns.
CARD_BRANDS = ["Visa","Mastercard","Amex","Discover"]
CHANNELS = ["card_present","ecommerce"]
# Base names for merchants; the generator appends a " #<n>" suffix to the chosen name.
BUSINESS_NAMES = [
    "River Market","Sunset Electronics","Liberty Diner","Bluebird Books","Metro Rides",
    "Pioneer Outfitters","Cedar Grocery","North Star Grill","Harbor Cafe","Prairie Style"
]

# -------------------- helpers --------------------
def _month_starts(end_date: datetime, months_total: int) -> List[datetime]:
    """Return the first day of each of the last `months_total` months, oldest first.

    The final element is the first day of `end_date`'s own month; the time of
    day of `end_date` is discarded.
    """
    last_month_start = datetime(end_date.year, end_date.month, 1)
    starts: List[datetime] = []
    # Walk from the furthest month back towards the month of end_date.
    for months_back in reversed(range(months_total)):
        shifted = last_month_start - pd.DateOffset(months=months_back)
        starts.append(shifted.to_pydatetime())
    return starts

def _pick_active_months(months_total: int, active_months: int) -> List[int]:
    """Randomly choose month indices in [0, months_total), returned in ascending order.

    At most `months_total` indices are drawn, so asking for more active months
    than exist simply yields every index.
    """
    n_active = active_months if active_months < months_total else months_total
    chosen = random.sample(range(months_total), n_active)
    chosen.sort()
    return chosen

def _next_id(prefix: str, counter: count) -> str:
    """Advance `counter` by one and return `prefix` followed by the value zero-padded to 10 digits."""
    serial = format(next(counter), "010d")
    return f"{prefix}{serial}"

def _month_key(dt: datetime) -> str:
    """Format `dt` as a 'YYYY-MM' string, used as the monthly aggregation key."""
    return f"{dt:%Y-%m}"

def _detect_rapid_pattern(df_month: pd.DataFrame,
                          small_amt_thresh: float,
                          small_repeats: int,
                          large_multiple: float,
                          window_minutes: int = 240) -> float:
    """Flag a "many small, then one large" pattern within a single calendar day.

    For each day in `df_month` (columns `transaction_datetime` and `amount`):
    if at least `small_repeats` transactions have amount <= `small_amt_thresh`,
    a window of `window_minutes` is opened at the first of them, and the day
    matches when any transaction inside that window (bounds inclusive) has
    amount >= `small_amt_thresh * large_multiple`.

    Returns 1.0 if any day matches, otherwise 0.0 (also 0.0 for an empty frame).
    """
    if df_month.empty:
        return 0.0

    ordered = df_month.sort_values("transaction_datetime").copy()
    ordered["date"] = ordered["transaction_datetime"].dt.date

    for _, day_txns in ordered.groupby("date"):
        day_txns = day_txns.sort_values("transaction_datetime")
        stamps = day_txns["transaction_datetime"]
        small_stamps = stamps[day_txns["amount"] <= small_amt_thresh].tolist()
        if len(small_stamps) < small_repeats:
            continue

        # The window is anchored on the earliest small transaction of the day.
        burst_start = small_stamps[0]
        burst_end = burst_start + timedelta(minutes=window_minutes)
        windowed = day_txns[(stamps >= burst_start) & (stamps <= burst_end)]
        if (windowed["amount"] >= small_amt_thresh * large_multiple).any():
            return 1.0
    return 0.0

# -------------------- main generator --------------------
def generate_ma_fake_transactions(
    # --- scale & horizon ---
    n_merchants: int = 200,
    months_total: int = 12,
    active_months_per_merchant: int = 12,
    end_date: Optional[datetime] = None,
    avg_txns_per_active_month: float = 800.0,
    currency: str = "USD",
    seed: Optional[int] = 42,

    # --- TARGET: features that SHOULD NOT trigger (kept calm) ---
    ensure_no_mismatch: bool = True,            # revenue_mismatch calm
    ensure_no_avg_ticket_shift: bool = True,    # avg_txn_amt_ratio calm
    ensure_no_rapid_pattern: bool = True,       # rapid_load_transfer calm

    # --- knobs for calm features ---
    declared_revenue_noise_std: float = 0.02,   # ~2% noise => low mismatch
    avg_ticket_monthly_jitter: float = 0.05,    # ~5% drift between months
    rapid_small_txn_amount: float = 4.99,
    rapid_small_repeats: int = 12,
    rapid_large_multiplier: float = 50.0,
    rapid_window_minutes: int = 240,

    # --- TARGET: features that SHOULD trigger (inject anomalies by default) ---
    # Low value ratio
    low_value_threshold: float = 5.00,
    pct_merchants_low_value_anom: float = 0.15,     # fraction of merchants with at least one month high ratio
    low_value_ratio_target: float = 0.40,           # target ratio in an anomalous month
    # Dormant account
    dormant_lookback_months: int = 3,
    pct_merchants_dormant_anom: float = 0.12,       # merchants that wake up after dormancy
    # Spike of Transactions
    spike_baseline_window: int = 6,
    spike_ratio_threshold: float = 1.8,
    pct_merchants_spike_anom: float = 0.15,         # merchants with 1 spiky month
    spike_volume_multiplier: float = 2.2,           # multiplies lambda in spike month
    # Refund count ratio
    refund_rate_base: float = 0.008,                # ~0.8% normal
    chargeback_rate_base: float = 0.001,            # optional (not in feature)
    pct_merchants_refund_anom: float = 0.14,        # merchants with high refund month
    refund_rate_anom: float = 0.05,                 # 5% refunds in anomalous month

) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generate synthetic merchant transactions and monthly per-merchant features.

    Args:
        n_merchants: number of merchants (ids "M100000", "M100001", ...).
        months_total: number of calendar months generated, ending with the
            month of `end_date`.
        active_months_per_merchant: accepted for interface compatibility but
            currently NOT used; every merchant is active in every month except
            the forced dormant prefix of the dormant cohort.
        end_date: any date inside the last generated month (defaults to the
            current UTC time).
        avg_txns_per_active_month: mean of the per-merchant baseline volume.
        currency: value written to the `currency` column.
        seed: seeds both `random` and `numpy.random`; None leaves them unseeded.
        ensure_no_mismatch / declared_revenue_noise_std: size of the noise
            between declared and observed revenue.
        ensure_no_avg_ticket_shift / avg_ticket_monthly_jitter: size of the
            month-to-month drift in average ticket and volume.
        ensure_no_rapid_pattern: when False, a random 8% cohort (at least one
            merchant) gets `rapid_small_repeats` small transactions plus one
            large one injected on a single day of every active month.
        rapid_small_txn_amount, rapid_small_repeats, rapid_large_multiplier,
            rapid_window_minutes: parameters of the injected pattern and of
            the detector that computes `rapid_load_transfer`.
        low_value_threshold, pct_merchants_low_value_anom,
            low_value_ratio_target: low-value anomaly cohort and intensity.
        dormant_lookback_months, pct_merchants_dormant_anom: the dormant cohort
            has no transactions in its first max(dormant_lookback_months, 1)
            months.
        spike_baseline_window, spike_ratio_threshold, pct_merchants_spike_anom,
            spike_volume_multiplier: spike cohort, detector window/threshold
            and volume boost in the spike month.
        refund_rate_base, refund_rate_anom, pct_merchants_refund_anom: refund
            probabilities in normal / anomalous months and the cohort size.
        chargeback_rate_base: chargeback probability (not used in any feature).

    Returns:
        transactions_df: one row per transaction, sorted by
            `transaction_datetime`.
        mm: one row per (merchant, month) with the columns
            merchant_id, year_month_str ('YYYY-MM'), year_month (datetime),
            low_value_trx_ratio, is_dormant_account, spike_of_trx,
            refund_count_ratio, revenue_mismatch, avg_txn_amt_ratio,
            rapid_load_transfer (all seven features are float).

    Trigger policy implemented:
      - TRIGGER (Yes): low_value_trx_ratio up, is_dormant_account=1.0 in the
                       wake-up month, spike_of_trx=1.0 in the spiky month,
                       refund_count_ratio up
      - NO TRIGGER (No): revenue_mismatch low, avg_txn_amt_ratio near 1,
                         rapid_load_transfer=0.0
    """
    # seed both generators used below
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
    if end_date is None:
        end_date = datetime.utcnow()

    # months: (first instant, last second) of every generated month, oldest first
    month_starts = _month_starts(end_date, months_total)
    month_ranges = [
        (ms, (ms + pd.DateOffset(months=1)).to_pydatetime() - timedelta(seconds=1))
        for ms in month_starts
    ]

    # merchants
    m_ids = [f"M{100000 + i}" for i in range(n_merchants)]
    m_names = {m: random.choice(BUSINESS_NAMES) + f" #{i%97+1}" for i, m in enumerate(m_ids)}
    m_state = {m: random.choice(US_STATES) for m in m_ids}
    mcc_pick = {m: random.choice(MCC_LIST) for m in m_ids}
    m_mcc = {m: code for m, (code, _) in mcc_pick.items()}
    m_mcc_desc = {m: desc for m, (_, desc) in mcc_pick.items()}

    # per-merchant base avg ticket: lognormal whose (median, sigma) depends on the MCC
    ticket_profile = {
        "5812": (18, 0.25),
        "5814": (18, 0.25),
        "5411": (32, 0.25),
        "5732": (85, 0.30),
        "4111": (14, 0.22),
    }
    m_base_ticket: Dict[str, float] = {}
    for m in m_ids:
        median, sigma = ticket_profile.get(m_mcc[m], (28, 0.28))
        m_base_ticket[m] = max(3.0, float(np.random.lognormal(np.log(median), sigma)))

    # baseline txn volume per merchant
    m_base_lambda = {
        m: max(30.0, np.random.normal(avg_txns_per_active_month, avg_txns_per_active_month * 0.2))
        for m in m_ids
    }

    # choose anomaly cohorts (for TRIGGER features); at least one merchant per cohort
    def pick_subset(pct: float) -> set:
        k = max(1, int(round(pct * n_merchants))) if n_merchants else 0
        return set(random.sample(m_ids, min(k, n_merchants))) if k > 0 else set()

    cohort_lowval = pick_subset(pct_merchants_low_value_anom)
    cohort_dormant = pick_subset(pct_merchants_dormant_anom)
    cohort_spike = pick_subset(pct_merchants_spike_anom)
    cohort_refund = pick_subset(pct_merchants_refund_anom)

    # For "No trigger" trio
    rapid_merchants = set() if ensure_no_rapid_pattern else pick_subset(0.08)

    # tracking
    tid = count(1)
    cid = count(1)
    rows: List[dict] = []
    # pre-aggregation holders keyed by (merchant_id, 'YYYY-MM')
    obs_sales: Dict[Tuple[str, str], float] = {}
    obs_count: Dict[Tuple[str, str], int] = {}
    refunds_count: Dict[Tuple[str, str], int] = {}
    low_value_count: Dict[Tuple[str, str], int] = {}
    # positions in `rows` (creation order) of each month's transactions, for the rapid detector
    per_month_txn_idx: Dict[Tuple[str, str], List[int]] = {}

    # For dormant anomaly: force first N months zero, then active
    dormant_cut_idx = max(dormant_lookback_months, 1)

    # choose one anomalous month per cohort member (for low value / spike / refund)
    def choose_anom_month() -> int:
        # Avoid the very first months so some baseline exists; the lower bound is
        # clamped so short horizons do not make randint() fail on an empty range.
        last_idx = months_total - 1
        first_idx = min(max(1, spike_baseline_window // 2), last_idx)
        return random.randint(first_idx, last_idx)

    # Iterate the cohorts in sorted order: set iteration order of strings depends on
    # hash randomisation, which would make seeded runs differ between sessions.
    lowval_month_by_m = {m: choose_anom_month() for m in sorted(cohort_lowval)}
    spike_month_by_m = {m: choose_anom_month() for m in sorted(cohort_spike)}
    refund_month_by_m = {m: choose_anom_month() for m in sorted(cohort_refund)}

    def _emit(m, ym, month_txn_indices, ts, amount, channel, is_refund, is_chargeback):
        """Append one transaction row and update the monthly aggregates for (m, ym)."""
        rows.append({
            "transaction_id": _next_id("T", tid),
            "merchant_id": m,
            "merchant_name": m_names[m],
            "mcc": m_mcc[m],
            "mcc_description": m_mcc_desc[m],
            "state": m_state[m],
            "transaction_datetime": ts,
            "amount": amount,
            "currency": currency,
            "channel": channel,
            "card_brand": random.choice(CARD_BRANDS),
            "is_refund": is_refund,
            "is_chargeback": is_chargeback,
            "customer_id": f"C{next(cid):09d}"
        })
        month_txn_indices.append(len(rows) - 1)
        key = (m, ym)
        obs_sales[key] = obs_sales.get(key, 0.0) + amount
        obs_count[key] = obs_count.get(key, 0) + 1
        refunds_count[key] = refunds_count.get(key, 0) + (1 if is_refund else 0)
        # Every transaction at or below the threshold counts as low-value,
        # including injected rapid-pattern rows.
        low_value_count[key] = low_value_count.get(key, 0) + (1 if amount <= low_value_threshold else 0)

    for m in m_ids:
        base_ticket = m_base_ticket[m]
        base_lambda = m_base_lambda[m]
        # dormant cohort: no transactions before dormant_cut_idx, activity resumes afterwards
        first_active_idx = dormant_cut_idx if m in cohort_dormant else 0

        for idx, (m_start, m_end) in enumerate(month_ranges):
            ym = _month_key(m_start)
            month_txn_indices: List[int] = []
            per_month_txn_idx[(m, ym)] = month_txn_indices

            if idx < first_active_idx:
                continue  # no transactions this month

            # monthly drift (calm vs. not)
            if ensure_no_avg_ticket_shift:
                ticket_mu = base_ticket * (1.0 + np.random.normal(0.0, avg_ticket_monthly_jitter))
                lam = max(5.0, np.random.normal(base_lambda, base_lambda * 0.07))
            else:
                ticket_mu = base_ticket * (1.0 + np.random.normal(0.0, max(0.22, avg_ticket_monthly_jitter*3)))
                lam = max(5.0, np.random.normal(base_lambda, base_lambda * 0.25))

            # spike anomaly: boost lambda in the chosen month
            if m in cohort_spike and idx == spike_month_by_m[m]:
                lam *= spike_volume_multiplier

            n_tx = int(max(1, np.random.poisson(lam)))

            # refund rate (anomalous month?)
            in_refund_month = m in cohort_refund and idx == refund_month_by_m[m]
            month_refund_rate = refund_rate_anom if in_refund_month else refund_rate_base

            # low-value anomaly: force a fraction of low-amount txns in that month
            make_lowval_heavy = (m in cohort_lowval and idx == lowval_month_by_m[m])
            if make_lowval_heavy and low_value_ratio_target:
                lowval_needed = int(round(low_value_ratio_target * n_tx))
            else:
                lowval_needed = 0

            # generate transactions at uniformly random instants within the month
            span_sec = int((m_end - m_start).total_seconds())
            lowvals_assigned = 0
            for _ in range(n_tx):
                ts = m_start + timedelta(seconds=random.randint(0, span_sec))

                # amount: assign explicit low-values first if anomaly month
                if make_lowval_heavy and lowvals_assigned < lowval_needed:
                    amount = round(float(np.random.uniform(0.5, low_value_threshold)), 2)
                    lowvals_assigned += 1
                else:
                    amount = float(np.random.lognormal(mean=np.log(ticket_mu), sigma=0.35))
                    amount = round(max(0.5, amount), 2)

                is_ref = bool(np.random.rand() < month_refund_rate)
                is_cbk = bool(np.random.rand() < chargeback_rate_base)

                _emit(m, ym, month_txn_indices, ts, amount, random.choice(CHANNELS), is_ref, is_cbk)

            # optional: (NOT trigger by default) rapid pattern
            if m in rapid_merchants:
                base_day = random.randint(3, 25)
                # many small transactions, two minutes apart starting at 09:00
                for i in range(rapid_small_repeats):
                    ts = m_start + timedelta(days=min(base_day, 27), hours=9, minutes=min(59, 2*i))
                    _emit(m, ym, month_txn_indices, ts, rapid_small_txn_amount, "ecommerce", False, False)

                # one big transaction at 12:00 the same day
                ts_big = m_start + timedelta(days=min(base_day, 27), hours=12, minutes=0)
                big_amt = round(rapid_small_txn_amount * rapid_large_multiplier, 2)
                _emit(m, ym, month_txn_indices, ts_big, big_amt, "ecommerce", False, False)

    # transactions df
    # `txn_creation_order` keeps the row positions recorded in per_month_txn_idx;
    # the returned frame is time-sorted and re-indexed, so those positions must
    # never be looked up in it.
    txn_creation_order = pd.DataFrame(rows)
    transactions_df = txn_creation_order
    if not txn_creation_order.empty:
        txn_creation_order["transaction_datetime"] = pd.to_datetime(txn_creation_order["transaction_datetime"])
        txn_creation_order["amount"] = txn_creation_order["amount"].astype(float)
        txn_creation_order["is_refund"] = txn_creation_order["is_refund"].astype(bool)
        txn_creation_order["is_chargeback"] = txn_creation_order["is_chargeback"].astype(bool)
        transactions_df = txn_creation_order.sort_values("transaction_datetime").reset_index(drop=True)

    # create full key grid merchant x month
    year_month = [_month_key(ms) for ms in month_starts]
    key_grid = pd.MultiIndex.from_product([m_ids, year_month], names=["merchant_id","year_month"]).to_frame(index=False)
    keys = list(zip(key_grid["merchant_id"], key_grid["year_month"]))

    # observed aggregates (months without transactions get zeros)
    key_grid["observed_sales"] = np.array([round(obs_sales.get(k, 0.0), 2) for k in keys], dtype=float)
    key_grid["txn_count"] = np.array([obs_count.get(k, 0) for k in keys], dtype="int64")
    key_grid["avg_ticket"] = (key_grid["observed_sales"] / key_grid["txn_count"].replace(0, np.nan)).fillna(0.0).round(2)
    key_grid["refunds"] = np.array([refunds_count.get(k, 0) for k in keys], dtype="int64")
    key_grid["low_value_cnt"] = np.array([low_value_count.get(k, 0) for k in keys], dtype="int64")

    # declared revenue (calm vs not)
    if ensure_no_mismatch:
        noise = np.random.normal(0.0, declared_revenue_noise_std, size=len(key_grid))
    else:
        noise = np.random.normal(0.0, max(0.12, declared_revenue_noise_std*6), size=len(key_grid))
    key_grid["declared_revenue"] = (key_grid["observed_sales"] * (1.0 + noise)).round(2)

    # ---------- FEATURES (ALL DOUBLE) ----------
    mm = key_grid[["merchant_id","year_month","observed_sales","txn_count","avg_ticket","declared_revenue","refunds","low_value_cnt"]].copy()

    def _per_merchant(column: str, fn) -> pd.Series:
        """Apply `fn` to each merchant's chronological values of `column`.

        An explicit loop is used instead of groupby.apply, which returns a wide
        DataFrame (not a Series) when there is only one merchant.
        """
        out = pd.Series(0.0, index=mm.index, dtype=float)
        for _, grp in mm.groupby("merchant_id", sort=False):
            out.loc[grp.index] = fn(grp[column].to_numpy(dtype=float))
        return out

    # 1) low_value_trx_ratio (DOUBLE)
    mm["low_value_trx_ratio"] = (mm["low_value_cnt"] / mm["txn_count"].replace(0, np.nan)).fillna(0.0).astype(float)

    # 2) is_dormant_account (DOUBLE: 1.0 if resumes after N zero months, else 0.0)
    def _dormant_flags(counts: np.ndarray) -> np.ndarray:
        out = np.zeros(len(counts), dtype=float)
        for i in range(max(0, dormant_lookback_months), len(counts)):
            if counts[i] > 0 and np.all(counts[i - dormant_lookback_months:i] == 0):
                out[i] = 1.0
        return out
    mm["is_dormant_account"] = _per_merchant("txn_count", _dormant_flags).astype(float)

    # 3) spike_of_trx (DOUBLE: 1.0 abnormal, else 0.0)
    def _spike_flags(vals: np.ndarray) -> np.ndarray:
        out = np.zeros(len(vals), dtype=float)
        for i in range(len(vals)):
            baseline = vals[max(0, i - spike_baseline_window):i]
            mean_base = baseline.mean() if baseline.size > 0 else 0.0
            if mean_base > 1.0 and vals[i] > mean_base * spike_ratio_threshold:
                out[i] = 1.0
        return out
    mm["spike_of_trx"] = _per_merchant("txn_count", _spike_flags).astype(float)

    # 4) refund_count_ratio (DOUBLE)
    mm["refund_count_ratio"] = (mm["refunds"] / mm["txn_count"].replace(0, np.nan)).fillna(0.0).astype(float)

    # 5) revenue_mismatch (DOUBLE)  -> keep calm by default
    mm["revenue_mismatch"] = (np.abs(mm["declared_revenue"] - mm["observed_sales"]) /
                              mm["observed_sales"].replace(0, np.nan)).fillna(0.0).astype(float)

    # 6) avg_txn_amt_ratio (DOUBLE) vs trailing median -> ~1.0 when calm
    def _avg_ratio(vals: np.ndarray) -> np.ndarray:
        out = np.ones(len(vals), dtype=float)
        for i in range(len(vals)):
            base = vals[max(0, i - 6):i]  # trailing 6 months
            med = np.median(base) if base.size > 0 else 0.0
            out[i] = (vals[i] / med) if med > 0 else 1.0
        return out
    mm["avg_txn_amt_ratio"] = _per_merchant("avg_ticket", _avg_ratio).astype(float)

    # 7) rapid_load_transfer (DOUBLE: 1.0/0.0) -> keep 0.0 by default
    def _rapid_flag(key: Tuple[str, str]) -> float:
        idx_list = per_month_txn_idx.get(key, [])
        if not idx_list:
            return 0.0
        # Look the positions up in the creation-ordered frame they were recorded against.
        month_df = txn_creation_order.loc[idx_list, ["transaction_datetime", "amount"]]
        return float(_detect_rapid_pattern(month_df,
                                           small_amt_thresh=rapid_small_txn_amount,
                                           small_repeats=rapid_small_repeats,
                                           large_multiple=rapid_large_multiplier,
                                           window_minutes=rapid_window_minutes))
    # `keys` is in the same row order as mm (mm is a straight copy of key_grid columns).
    mm["rapid_load_transfer"] = np.array([_rapid_flag(k) for k in keys], dtype=float)

    mm["year_month"] = pd.to_datetime(mm["year_month"], format="%Y-%m")
    mm["year_month_str"] = mm["year_month"].dt.strftime("%Y-%m")
    # final cast/order (all features DOUBLE)
    feature_cols = [
        "low_value_trx_ratio",
        "is_dormant_account",
        "spike_of_trx",
        "refund_count_ratio",
        "revenue_mismatch",
        "avg_txn_amt_ratio",
        "rapid_load_transfer",
    ]
    mm = (mm[["merchant_id", "year_month_str", "year_month"] + feature_cols]
          .sort_values(["merchant_id", "year_month"])
          .reset_index(drop=True))

    # ensure float dtype
    for c in feature_cols:
        mm[c] = mm[c].astype(float)

    return transactions_df, mm


In [None]:
# 1)  150 merchants, 18 months
# Build the demo data set: df_trx holds the individual transactions and agg_df the
# monthly per-merchant features. Months run through June 2025 (the month of end_date).
# NOTE: active_months_per_merchant is accepted by the generator but not read by it,
# so merchants are active in all 18 months apart from the forced dormant cohort.
df_trx, agg_df = generate_ma_fake_transactions(
    n_merchants=150,
    months_total=18,
    active_months_per_merchant=12, 
    end_date=datetime(2025, 6, 30),
    avg_txns_per_active_month=35, 
    currency="USD",
    ensure_no_mismatch=True,
    ensure_no_avg_ticket_shift=True,
    ensure_no_rapid_pattern=True,

     # low value
    low_value_threshold=9.0,
    pct_merchants_low_value_anom=0.30,
    low_value_ratio_target=0.50,

    # dormant
    dormant_lookback_months=3,
    pct_merchants_dormant_anom=0.25,

    # spike
    spike_baseline_window=4,
    spike_ratio_threshold=1.6,
    pct_merchants_spike_anom=0.30,
    spike_volume_multiplier=2.6,

    # refunds
    refund_rate_base=0.008,
    chargeback_rate_base=0.001,
    pct_merchants_refund_anom=0.30,
    refund_rate_anom=0.075
)

In [None]:
import numpy as np

np.random.seed(42)

# Replace the three "calm" features computed by the generator with tightly clustered
# synthetic values. Each column is drawn from a normal distribution and immediately
# clipped to a reasonable range so no extreme values remain. The columns are drawn in
# this fixed order, which determines the random numbers each one receives.
agg_df["revenue_mismatch"] = np.clip(np.random.normal(1.0, 0.02, len(agg_df)), 0.9, 1.1)
agg_df["avg_txn_amt_ratio"] = np.clip(np.random.normal(0.95, 0.015, len(agg_df)), 0.9, 1.05)
agg_df["rapid_load_transfer"] = np.clip(np.random.normal(0.05, 0.01, len(agg_df)), 0.0, 0.1)

### 2. Anomalous KYC

In [None]:
import pandas as pd
import numpy as np
import random
import re
from datetime import date, timedelta
from typing import Iterable, Optional

# ------------------ small US catalogs ------------------
# Two-letter codes: the 50 states plus DC and PR (same list as in the transactions cell).
US_STATES = ["AL","AK","AZ","AR","CA","CO","CT","DE","FL","GA","HI","ID","IL","IN","IA","KS","KY","LA","ME","MD",
             "MA","MI","MN","MS","MO","MT","NE","NV","NH","NJ","NM","NY","NC","ND","OH","OK","OR","PA","RI","SC",
             "SD","TN","TX","UT","VT","VA","WA","WV","WI","WY","DC","PR"]

# (city, state code) pairs sampled together by _rand_address so the two stay consistent.
US_CITIES = [
    ("New York","NY"),("Los Angeles","CA"),("Chicago","IL"),("Houston","TX"),("Phoenix","AZ"),
    ("Philadelphia","PA"),("San Antonio","TX"),("San Diego","CA"),("Dallas","TX"),("San Jose","CA"),
    ("Austin","TX"),("Jacksonville","FL"),("Fort Worth","TX"),("Columbus","OH"),("Charlotte","NC")
]

# (code, description) pairs; _choose_mcc uses them to fill a missing description
# or to draw a random pair when no MCC is supplied.
MCC_LIST = [
    ("5411","Grocery Stores, Supermarkets"),
    ("5812","Eating Places, Restaurants"),
    ("5814","Fast Food Restaurants"),
    ("4111","Local Passenger Transportation"),
    ("5732","Electronics Stores"),
    ("5691","Clothing Stores"),
    ("5942","Book Stores"),
    ("4814","Telecommunication Services"),
    ("5977","Cosmetic Stores"),
    ("5999","Specialty Retail Stores"),
]

# Building blocks combined by _rand_business_name into "<word> <seed> <suffix>".
BUSINESS_WORDS = ["River","Sunset","Liberty","Bluebird","Metro","Pioneer","Cedar","North Star","Harbor","Prairie"]
BUSINESS_SUFFIX = ["LLC","Inc.","Corp.","Ltd.","Company","Group","Holdings"]

HIGH_RISK_INDUSTRIES = {"5944","5967","5993","7273","7995","6051","6211"}  # (typical examples of high-risk MCCs)
# Note: none of these codes appear in MCC_LIST above, so if your demo only uses MCC_LIST,
# high-risk merchants are marked through an independent flag instead.

# ------------------ helpers ------------------
def _rand_business_name(seed_str: str) -> str:
    """Build a business name "<word> <seed_str> <suffix>" from random catalog entries."""
    word = random.choice(BUSINESS_WORDS)
    legal_form = random.choice(BUSINESS_SUFFIX)
    return "{} {} {}".format(word, seed_str, legal_form)

def _rand_phone() -> str:
    """Return a random "+1" phone number: two 3-digit groups (200-999) and a 4-digit group."""
    first_group = random.randint(200, 999)
    second_group = random.randint(200, 999)
    last_group = random.randint(1000, 9999)
    return "+1" + str(first_group) + str(second_group) + str(last_group)

def _rand_email(bname: str) -> str:
    """Derive a contact e-mail from a business name.

    The name is lower-cased, stripped of everything except a-z and 0-9, and cut
    to 12 characters; if nothing remains, 'merchant' is used as the domain.
    """
    slug = re.sub(r"[^a-z0-9]+","", bname.lower())[:12]
    if not slug:
        slug = 'merchant'
    return f"contact@{slug}.com"

def _rand_website(bname: str) -> str:
    """Derive a website URL from a business name.

    The name is lower-cased, stripped of everything except a-z and 0-9, and cut
    to 15 characters; if nothing remains, 'merchant' is used as the host name.
    """
    slug = re.sub(r"[^a-z0-9]+","", bname.lower())[:15]
    if not slug:
        slug = 'merchant'
    return f"https://www.{slug}.com"

def _rand_postal_code() -> str:
    """Return a random 5-digit postal code (10000-99999) as a string."""
    return str(random.randint(10000, 99999))

def _rand_address() -> tuple[str,str,str,str]:
    """Return a random (street line, city, state code, postal code) tuple.

    City and state are drawn together from US_CITIES so they match each other.
    """
    house_number = random.randint(10, 9999)
    street_name = random.choice(["Main St","Market St","Broadway","1st Ave","2nd Ave","Park Ave","Oak St","Pine St"])
    city, state_code = random.choice(US_CITIES)
    street_line = f"{house_number} {street_name}"
    postal_code = _rand_postal_code()
    return (street_line, city, state_code, postal_code)

def _rand_ein(correct: bool = True) -> str:
    """Return a random EIN-like tax id.

    With `correct=True` the value has the NN-NNNNNNN shape. Otherwise three
    malformed candidates are generated and one of them is returned at random,
    to simulate a tax-id anomaly.
    """
    if not correct:
        # All three candidates are always generated before one is picked.
        no_hyphen = "{}{}{}".format(
            random.randint(1, 9), random.randint(0, 9), random.randint(1000000, 9999999)
        )
        wrong_grouping = "{}-{}".format(random.randint(100, 999), random.randint(100000, 999999))
        trailing_letter = "{}-{}{}".format(
            random.randint(10, 99), random.randint(10000, 99999), random.choice(['A','X'])
        )
        return random.choice([no_hyphen, wrong_grouping, trailing_letter])

    # EIN format: NN-NNNNNNN
    prefix = random.randint(10, 99)
    serial = random.randint(1000000, 9999999)
    return f"{prefix}-{serial}"

def _choose_mcc(meta_row: Optional[pd.Series]) -> tuple[str,str]:
    """Return an (mcc code, description) pair for a merchant.

    When `meta_row` carries a non-null "mcc", that code is used (as a string)
    together with its "mcc_description"; an empty description is looked up in
    MCC_LIST and falls back to "NA". Without a usable "mcc", a random pair from
    MCC_LIST is returned.
    """
    has_mcc = meta_row is not None and "mcc" in meta_row and pd.notnull(meta_row["mcc"])
    if not has_mcc:
        code, desc = random.choice(MCC_LIST)
        return code, desc

    code = str(meta_row["mcc"])
    desc = meta_row.get("mcc_description", "")
    if not desc:
        # Fill the missing description from the catalog when the code is known.
        desc = dict(MCC_LIST).get(code, desc)
    return code, (desc or "NA")

# ------------------ main ------------------
def generate_merchant_kyc(
    merchant_ids: Optional[Iterable[str]] = None,
    transactions_df: Optional[pd.DataFrame] = None,  # alt: derive ids & meta (merchant_name/mcc/state)
    merchant_meta: Optional[pd.DataFrame] = None,    # optional columns: merchant_id, merchant_name, mcc, mcc_description, state
    seed: int = 42,

    # anomaly knobs (prevalence)
    pct_sanctioned_hit: float = 0.01,
    pct_pep_owner: float = 0.02,
    pct_adverse_media: float = 0.06,
    pct_missing_docs: float = 0.10,
    pct_incorrect_tax_id: float = 0.05,
    pct_po_box_address: float = 0.05,
    pct_high_risk_industry: float = 0.08,

    # declared economics (rough priors)
    avg_ticket_declared_mu: float = 30.0,
    avg_ticket_declared_sigma: float = 0.5,   # lognormal sigma
    monthly_volume_declared_mu: float = 75000.0,
    monthly_volume_declared_sigma: float = 0.7,

    # refund / chargeback policies
    refund_policy_days_choices: tuple = (7, 14, 30),
) -> pd.DataFrame:
    """
    Create simple synthetic KYC records, one row per merchant id.

    Args:
        merchant_ids: Merchant ids to generate records for. Values are
            converted to ``str`` and de-duplicated (first occurrence wins).
            If omitted, the unique non-null ``merchant_id`` values of
            ``transactions_df`` are used.
        transactions_df: Optional transactions frame. Used to infer the
            merchant ids (when ``merchant_ids`` is None) and, when
            ``merchant_meta`` is not given, to harvest merchant_name / mcc /
            mcc_description / state per merchant (only fully populated rows
            are considered).
        merchant_meta: Optional per-merchant frame with a ``merchant_id``
            column and any of merchant_name, mcc, mcc_description, state.
            Takes precedence over ``transactions_df`` as the metadata source.
        seed: Seed applied to both ``random`` and ``numpy.random``.
        pct_*: Fraction of merchants placed in each anomaly cohort. Any
            positive fraction selects at least one merchant.
        avg_ticket_declared_mu / avg_ticket_declared_sigma: Median and
            log-space sigma of the lognormal declared average ticket.
        monthly_volume_declared_mu / monthly_volume_declared_sigma: Same,
            for the declared monthly volume.
        refund_policy_days_choices: Values the refund policy is drawn from.

    Returns:
        DataFrame with identity, contact, MCC, declared-economics, ownership,
        anomaly-flag and ``risk_score`` columns. For an empty merchant list
        an empty frame with the same columns is returned.

    Raises:
        ValueError: if neither ``merchant_ids`` nor ``transactions_df`` is given.

    Note:
        ``incorporation_date`` is computed relative to ``date.today()``, so it
        is not reproducible across days even with a fixed ``seed``.
    """
    random.seed(seed)
    np.random.seed(seed)

    # ---- resolve merchant universe ----
    if merchant_ids is None:
        if transactions_df is None:
            raise ValueError("Provide either merchant_ids or transactions_df.")
        merchant_ids = (
            transactions_df[["merchant_id"]]
            .dropna().drop_duplicates()["merchant_id"].astype(str).tolist()
        )
    else:
        merchant_ids = [str(x) for x in merchant_ids]
    # One row per merchant: drop repeated ids (also those that only collide
    # after the str conversion) while keeping the original order.
    merchant_ids = list(dict.fromkeys(merchant_ids))

    # ---- resolve optional per-merchant metadata ----
    meta_source = None
    if merchant_meta is not None:
        meta_source = merchant_meta
    elif transactions_df is not None:
        # try to harvest basic meta from transactions
        cols = [c for c in ["merchant_id","merchant_name","mcc","mcc_description","state"] if c in transactions_df.columns]
        if "merchant_id" in cols:
            meta_source = transactions_df[cols].dropna()

    meta = None
    if meta_source is not None:
        meta = meta_source[meta_source["merchant_id"].notna()].copy()
        # merchant_ids are strings at this point; index the metadata by the
        # same string form, otherwise numeric ids would never match.
        meta["merchant_id"] = meta["merchant_id"].astype(str)
        meta = meta.drop_duplicates(subset=["merchant_id"]).set_index("merchant_id")

    # ---- assign anomaly cohorts ----
    def pick_subset(pct: float) -> set:
        """Sample round(pct * n) merchants (at least one when pct > 0)."""
        n = len(merchant_ids)
        k = max(1, int(round(pct * n))) if n > 0 and pct > 0 else 0
        return set(random.sample(merchant_ids, min(k, n))) if k > 0 else set()

    cohort_sanctions  = pick_subset(pct_sanctioned_hit)
    cohort_pep        = pick_subset(pct_pep_owner)
    cohort_adverse    = pick_subset(pct_adverse_media)
    cohort_missdocs   = pick_subset(pct_missing_docs)
    cohort_bad_tax    = pick_subset(pct_incorrect_tax_id)
    cohort_pobox      = pick_subset(pct_po_box_address)
    cohort_highrisk   = pick_subset(pct_high_risk_industry)

    # Output schema, in column order. Passing it to the DataFrame constructor
    # keeps the columns present even when there are no rows.
    columns = [
        "merchant_id",
        "legal_name", "business_name", "business_type", "incorporation_date",
        "tax_id_ein",
        "country", "state", "city", "postal_code", "address_line",
        "phone_number", "email", "website",
        "mcc", "mcc_description",
        "average_ticket_declared", "monthly_volume_declared",
        "refund_policy_days", "chargeback_contact_email",
        "beneficial_owners_count", "top_owner_share",
        "sanctioned_hit", "pep_beneficial_owner", "adverse_media", "missing_kyc_docs",
        "incorrect_tax_id_format", "po_box_address", "high_risk_industry",
        "risk_score",
    ]

    rows = []
    today = date.today()

    for mid in merchant_ids:
        # basic identity / naming
        # The fallback name is always drawn first so the random stream does
        # not depend on whether metadata supplies a usable name.
        fallback_name = _rand_business_name(mid)
        meta_row = meta.loc[mid] if meta is not None and mid in meta.index else None
        if meta_row is not None:
            raw_name = meta_row.get("merchant_name", None)
            # A missing (None/NaN) or blank name falls back to the generated
            # one instead of turning into the literal string "nan".
            name_text = "" if pd.isna(raw_name) else str(raw_name).strip()
            mname = name_text or fallback_name
            state_from_tx = meta_row.get("state", None)
        else:
            mname = fallback_name
            state_from_tx = None
        mcc_code, mcc_desc = _choose_mcc(meta_row)

        street, city, st, zipc = _rand_address()
        if state_from_tx and state_from_tx in US_STATES:
            st = state_from_tx  # keep consistency with transactions if provided

        # anomalies: PO Box address
        if mid in cohort_pobox:
            street = f"PO Box {random.randint(10, 9999)}"

        # EIN (tax id)
        ein_ok = mid not in cohort_bad_tax
        tax_id = _rand_ein(correct=ein_ok)

        # declared economics (simple priors)
        avg_ticket_decl = float(np.round(np.random.lognormal(mean=np.log(avg_ticket_declared_mu), sigma=avg_ticket_declared_sigma), 2))
        monthly_volume_decl = float(np.round(np.random.lognormal(mean=np.log(monthly_volume_declared_mu), sigma=monthly_volume_declared_sigma), 2))

        # business attributes
        business_type = random.choice(["Sole Proprietor","LLC","Corporation","Partnership","Non-Profit"])
        incorporation_date = today - timedelta(days=random.randint(365*1, 365*25))

        # ownership: random share split among 1-4 owners, largest share first
        bo_count = random.choice([1,2,3,4])
        bo_pec = sorted(np.random.dirichlet(np.ones(bo_count)), reverse=True)
        bo_top_share = float(np.round(bo_pec[0], 4))

        # policies & contacts
        refund_days = random.choice(refund_policy_days_choices)
        phone = _rand_phone()
        email = _rand_email(mname)
        website = _rand_website(mname)

        # risk & anomalies (flags as bool; you can cast to float if you prefer)
        is_sanctioned_match = (mid in cohort_sanctions)
        pep_beneficial_owner = (mid in cohort_pep)
        adverse_media = (mid in cohort_adverse)
        missing_kyc_docs = (mid in cohort_missdocs)
        po_box_address = (mid in cohort_pobox)
        incorrect_tax_id_format = (mid in cohort_bad_tax)
        is_high_risk_industry = (mid in cohort_highrisk) or (mcc_code in HIGH_RISK_INDUSTRIES)

        # simple risk score (0–100): base 20 plus a fixed weight per signal,
        # plus Gaussian noise, clipped to the valid range
        score = 20.0
        score += 25.0 if is_sanctioned_match else 0.0
        score += 15.0 if pep_beneficial_owner else 0.0
        score += 10.0 if adverse_media else 0.0
        score += 7.0  if missing_kyc_docs else 0.0
        score += 5.0  if incorrect_tax_id_format else 0.0
        score += 5.0  if po_box_address else 0.0
        score += 8.0  if is_high_risk_industry else 0.0
        risk_score = float(np.clip(score + np.random.normal(0, 3), 0, 100))

        rows.append({
            # keys
            "merchant_id": str(mid),

            # legal/commercial identity
            "legal_name": mname.replace(" LLC","").replace(" Inc.",""),
            "business_name": mname,
            "business_type": business_type,
            "incorporation_date": pd.to_datetime(incorporation_date).date(),

            # registration / tax
            "tax_id_ein": tax_id,

            # geography / contact
            "country": "US",
            "state": st,
            "city": city,
            "postal_code": zipc,
            "address_line": street,
            "phone_number": phone,
            "email": email,
            "website": website,

            # merchant category
            "mcc": str(mcc_code),
            "mcc_description": mcc_desc,

            # declared economics
            "average_ticket_declared": avg_ticket_decl,
            "monthly_volume_declared": monthly_volume_decl,

            # policy
            "refund_policy_days": int(refund_days),
            "chargeback_contact_email": f"chargebacks@{re.sub(r'[^a-z0-9]+','', mname.lower())[:12] or 'merchant'}.com",

            # ownership summary
            "beneficial_owners_count": int(bo_count),
            "top_owner_share": bo_top_share,   # 0–1

            # anomaly/risk flags
            "sanctioned_hit": bool(is_sanctioned_match),
            "pep_beneficial_owner": bool(pep_beneficial_owner),
            "adverse_media": bool(adverse_media),
            "missing_kyc_docs": bool(missing_kyc_docs),
            "incorrect_tax_id_format": bool(incorrect_tax_id_format),
            "po_box_address": bool(po_box_address),
            "high_risk_industry": bool(is_high_risk_industry),

            # score
            "risk_score": float(risk_score),
        })

    df_kyc = pd.DataFrame(rows, columns=columns)

    # enforce dtypes
    str_cols = ["merchant_id","legal_name","business_name","business_type","tax_id_ein",
                "country","state","city","postal_code","address_line","phone_number",
                "email","website","mcc","mcc_description","chargeback_contact_email"]
    for c in str_cols:
        df_kyc[c] = df_kyc[c].astype("object")

    df_kyc["incorporation_date"] = pd.to_datetime(df_kyc["incorporation_date"]).dt.date
    df_kyc["refund_policy_days"] = pd.to_numeric(df_kyc["refund_policy_days"], errors="coerce").astype(int)
    df_kyc["beneficial_owners_count"] = pd.to_numeric(df_kyc["beneficial_owners_count"], errors="coerce").astype(int)
    df_kyc["top_owner_share"] = pd.to_numeric(df_kyc["top_owner_share"], errors="coerce").astype(float)
    df_kyc["average_ticket_declared"] = pd.to_numeric(df_kyc["average_ticket_declared"], errors="coerce").astype(float)
    df_kyc["monthly_volume_declared"] = pd.to_numeric(df_kyc["monthly_volume_declared"], errors="coerce").astype(float)
    df_kyc["risk_score"] = (
        pd.to_numeric(df_kyc["risk_score"], errors="coerce")
        .fillna(0)
        .astype("float64")
        .round(2)
    )

    bool_cols = ["sanctioned_hit","pep_beneficial_owner","adverse_media","missing_kyc_docs",
                 "incorrect_tax_id_format","po_box_address","high_risk_industry"]
    for c in bool_cols:
        df_kyc[c] = df_kyc[c].astype(bool)

    return df_kyc


In [None]:
# 1) If you already have transactions from your generator:
# take the distinct non-null merchant ids from the transactions frame.
merchant_ids = df_trx["merchant_id"].dropna().unique().tolist()

# Build one KYC row per merchant. transactions_df is passed in addition to
# merchant_ids so the generator can reuse merchant_name / mcc / state from the
# transactions; the pct_* values below equal the function defaults and are
# spelled out only to make the anomaly prevalence visible here.
df_kyc_anomalos = generate_merchant_kyc(
    merchant_ids=merchant_ids,     
    transactions_df=df_trx,           
    pct_sanctioned_hit=0.01,
    pct_pep_owner=0.02,
    pct_adverse_media=0.06,
    pct_missing_docs=0.10,
    pct_incorrect_tax_id=0.05,
    pct_po_box_address=0.05,
    pct_high_risk_industry=0.08
)

# Write

### 1. Transactions Dataset

In [None]:
# Convert the pandas transactions frame to a Spark DataFrame and write it to
# the transactions dataset in the PUBLIC data environment.
dataset_functions.write(
    context,
    context.get_spark_session().createDataFrame(df_trx),
    transactions_dataset().identifier,
    data_environment=DataEnvironment.PUBLIC)

# Publish the same dataset after the write.
# NOTE(review): presumably publishing makes the written data available to
# downstream consumers — confirm against the dataset_functions documentation.
dataset_functions.publish(context, 
                          transactions_dataset().identifier,
                          data_environment=DataEnvironment.PUBLIC)

### 2. Aggregate Features Dataset

In [None]:
# Convert the aggregated features frame (agg_df) to a Spark DataFrame and
# write it to the customer-monthly dataset in the PUBLIC data environment.
dataset_functions.write(context, 
                        context.get_spark_session().createDataFrame(agg_df), 
                        customer_monthly_dataset().identifier,
                        data_environment=DataEnvironment.PUBLIC)



# Publish the same dataset after the write.
dataset_functions.publish(context, 
                          customer_monthly_dataset().identifier,
                          data_environment=DataEnvironment.PUBLIC)

### 3. Anomalous KYC Dataset

In [None]:
# Convert the generated KYC frame (df_kyc_anomalos) to a Spark DataFrame and
# write it to the customers dataset in the PUBLIC data environment.
# NOTE(review): the rows are per-merchant KYC records but the target is
# customers_dataset() — confirm this is the intended dataset.
dataset_functions.write(context, 
                        context.get_spark_session().createDataFrame(df_kyc_anomalos),
                        customers_dataset().identifier,
                        data_environment=DataEnvironment.PUBLIC)

# Publish the same dataset after the write.
dataset_functions.publish(context, 
                          customers_dataset().identifier,
                          data_environment=DataEnvironment.PUBLIC)

In [None]:
# Close the context created by init_context() at the top of the notebook.
context.close()