In [None]:
import datetime
from dateutil.relativedelta import relativedelta
import json
import logging
from pyspark.sql import DataFrame, Window, functions as f
from pyspark.sql import SQLContext
from pyspark.sql.types import LongType
import yaml

from common.libs import dates as dates_lib
from common.libs import features_discovery
from common.libs.features_executor import FeaturesExecutor
from common.libs.feature_engineering import max_look_back_monthly_features, max_look_back_daily_weekly_features
from common.libs.zscore import enrich_with_z_score
from common.factory.wrangling_execution_strategy import get_wrangling_execution_strategy
from common.factory.eval_flow_definition import get_evaluation_flow_definition
from common.factory.domain_definition import get_domain_definition
from common.notebook_utils.wrangling.wrangling_execution_strategy import WranglingExecutionStrategy
from common.definitions.domain import DomainDefinition
from common.definitions.eval_flow import EvaluationFlowDefinition
from common.libs.context_utils import get_dataset

from thetaray.api.context import init_context
from thetaray.api.dataset import dataset_functions
from thetaray.api.solution import IngestionMode
from thetaray.common import Constants
from thetaray.common.data_environment import DataEnvironment

# Route root-logger output through a compact "LEVEL: time @ message" format at INFO level.
# logging.basicConfig() does nothing once the root logger already has a handler, so its level=
# argument was silently ignored whenever handlers[0] existed; set the level explicitly instead.
# Conversely, indexing handlers[0] raised IndexError when no handler was installed yet.
_log_formatter = logging.Formatter(fmt='%(levelname)s: %(asctime)s @ %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
_root_logger = logging.getLogger()
if not _root_logger.handlers:
    logging.basicConfig()
_root_logger.handlers[0].setFormatter(_log_formatter)
_root_logger.setLevel(logging.INFO)

import pandas as pd
# Widen pandas' console rendering so wide/long frames print without truncation.
for _option, _value in (
    ('display.max_rows', 500),
    ('display.max_columns', 500),
    ('display.width', 1000),
):
    pd.set_option(_option, _value)


from thetaray.api.context import init_context
import datetime
from thetaray.common import Constants

from common.libs.config.loader import load_config
from common.libs.config.basic_execution_config_loader import BasicExecutionConfig, DevBasicExecutionConfig
from common.libs.context_utils import is_run_triggered_from_airflow



# Spark settings for this domain come from the repo's YAML config, profile "spark_config_a".
with open('/thetaray/git/solutions/domains/demo_fuib/config/spark_config.yaml') as spark_config_file:
    spark_config = yaml.load(spark_config_file, yaml.FullLoader)['spark_config_a']

# Fixed execution date (1970-01-01); presumably a placeholder since the data below is synthetic — confirm.
execution_date=datetime.datetime(1970, 1, 1)

context = init_context(domain='demo_fuib',
                       execution_date=execution_date,
                       spark_conf=spark_config,
                       spark_master='local[*]')

spark = context.get_spark_session()
# SQLContext's first parameter is a SparkContext (the session is the optional second one).
# Passing the SparkSession itself only appears to work because the session exposes the same
# private _jsc/_jvm attributes.
sc = SQLContext(spark.sparkContext, spark)
params = context.parameters
print(f"Spark UI URL: {context.get_spark_ui_url()}")

print(json.dumps(params, indent=4))

In [None]:
# from domains.demo_fuib.libs.transactions_generator_spark import generate_transactions_spark

# df = generate_transactions_spark(
#     spark,
#     start_date="2023-07-01",
#     end_date="2025-06-30",
#     n_customers=1200,
#     pct_anomalous_monthly=0.1,
#     avg_txn_per_customer_per_month=(5, 60),
#     seed=1337
# )
# df.count()

In [None]:
from pyspark.sql import functions as F, Window as W

def build_train_and_pred_base(
    spark,
    *,
    n_customers: int = 1500,
    seed: int = 1337,
    max_train_rows: int = 1_000_000,
):
    """Generate synthetic transactions and split them into a capped train set and a June-2025 base.

    Transactions are generated once for 2023-07-01..2025-06-30 so the same customers appear in
    both outputs. Rows before 2025-06-01 form ``train``; rows in June 2025 form ``pred_base``.
    When ``train`` exceeds ``max_train_rows`` it is down-sampled per customer, keeping each
    customer's share of rows (at least one row each), then trimmed to the exact cap if needed.

    Args:
        spark: active SparkSession passed to the generator.
        n_customers: number of synthetic customers to generate.
        seed: seed for the generator and for the sampling order.
        max_train_rows: upper bound on the returned ``train`` row count.

    Returns:
        ``(train, pred_base)`` DataFrames, both marked for caching. The generated frame they derive
        from is also left cached.

    Raises:
        ValueError: if ``max_train_rows`` is negative.
    """
    if max_train_rows < 0:
        raise ValueError(f"max_train_rows must be >= 0, got {max_train_rows}")

    # 1) Generate once through June 30 (so the same customers exist in June)
    from domains.demo_fuib.libs.transactions_generator_spark import generate_transactions_spark

    df_all = generate_transactions_spark(
        spark,
        start_date="2023-07-01",
        end_date="2025-06-30",
        n_customers=n_customers,
        pct_anomalous_monthly=0.00,                 # keep base anomalies very low (or 0.0)
        avg_txn_per_customer_per_month=(5, 60),
        seed=seed
    ).cache()

    # 2) Split into train (≤ 2025-05-31) and pred_base (2025-06)
    train = df_all.where(F.col("txn_ts") < F.lit("2025-06-01"))
    pred_base = df_all.where((F.col("txn_ts") >= F.lit("2025-06-01")) & (F.col("txn_ts") < F.lit("2025-07-01")))

    # 3) Cap train rows with per-customer proportional sampling (no heavy skew)
    cnt = train.count()
    if cnt > max_train_rows:
        # One window partition per customer gives both a seeded random rank and the customer's
        # row count. The old groupBy + join needed an extra aggregation job to recompute `cnt`.
        by_customer = W.partitionBy("customer_id")
        ranked = (
            train
            .withColumn("rn", F.row_number().over(by_customer.orderBy(F.rand(seed))))
            .withColumn("cust_cnt", F.count(F.lit(1)).over(by_customer))
        )
        # Quota proportional to the customer's share of all train rows (truncated to int).
        keep = (F.col("cust_cnt") / F.lit(cnt) * F.lit(max_train_rows)).cast("int")
        train = (
            ranked.where(F.col("rn") <= F.greatest(keep, F.lit(1)))   # at least one row per customer
                  .drop("rn", "cust_cnt")
        )

        # The ≥1-row floor can overshoot the cap; trim to exactly max_train_rows if so.
        if train.count() > max_train_rows:
            train = train.orderBy(F.rand(seed + 1)).limit(max_train_rows)

    return train.cache(), pred_base.cache()

In [None]:
from typing import Dict, Any, Tuple, Optional, List
from datetime import datetime, timedelta
import random, string, numpy as np
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql import functions as F

# Counterparty jurisdictions the injection patterns treat as high-risk.
HIGH_RISK_DEST = {"AE", "CY", "TR"}
# Hard-coded FX rates (not live) used to express every amount in USD.
FX_RATES_TO_USD = {"UAH": 0.027, "USD": 1.00, "EUR": 1.10}
# Currency universe in dict order (UAH, USD, EUR); this order lines up with the sampling weights
# used when picking a transaction currency.
CURRENCIES = list(FX_RATES_TO_USD)

def _rand_id(prefix: str, n: int = 10) -> str:
    s = "".join(random.choices(string.ascii_uppercase + string.digits, k=n))
    return f"{prefix}_{s}"

def _usd(amount: float, currency: str) -> float:
    """Convert ``amount`` in ``currency`` to USD, rounded to cents.

    Currencies missing from FX_RATES_TO_USD are passed through at a rate of 1.0.
    """
    rate = FX_RATES_TO_USD.get(currency, 1.0)
    usd_value = round(amount * rate, 2)
    return float(usd_value)

def _pick_currency_for_usd_target(usd_target: float) -> Tuple[str, float]:
    """Pick a transaction currency (UAH 20%, USD 70%, EUR 10%) and express ``usd_target`` in it.

    Returns ``(currency, amount)`` with the amount rounded to 2 decimals.
    """
    currency = random.choices(CURRENCIES, weights=[0.2, 0.7, 0.1], k=1)[0]
    if currency != "USD":
        return currency, float(round(usd_target / FX_RATES_TO_USD[currency], 2))
    return "USD", float(round(usd_target, 2))

def _sample_times(start_dt: datetime, end_dt: datetime, n: int) -> List[datetime]:
    span = int((end_dt - start_dt).total_seconds())
    offs = sorted(np.random.randint(0, max(1, span+1), size=n).tolist())
    return [start_dt + timedelta(seconds=int(o)) for o in offs]

def _pick_3day_window(last_start: datetime, last_end: datetime) -> Tuple[datetime, datetime]:
    latest_start = last_end - timedelta(days=3)
    if latest_start <= last_start: s = last_start
    else:
        span = int((latest_start - last_start).total_seconds())
        s = last_start + timedelta(seconds=int(np.random.randint(0, span+1)))
    return s, s + timedelta(days=3)

def _build_row(c: Dict[str, Any], ts: datetime, *, method: str, direction: str,
               usd_amount: float, purpose: str = "services",
               cp_id: Optional[str] = None, cp_jur: Optional[str] = None,
               cross_border: bool = False, known_shell: bool = False,
               same_cp: bool = False) -> Row:
    """Build one injected transaction Row for the customer snapshot ``c``.

    Args:
        c: candidate dict carrying at least ``customer_id`` and ``account_id``; KYC/linkage keys
            are copied when present.
        ts: transaction timestamp.
        method, direction, purpose: transaction descriptors copied verbatim.
        usd_amount: target USD value; a currency is sampled and the amount converted into it.
        cp_id: counterparty id; a random ``CP_...`` id is generated when omitted.
        cp_jur: counterparty jurisdiction (may be None).
        cross_border, known_shell: flags copied into the row.
        same_cp: pin the counterparty to a fixed id: ``cp_id`` when given, otherwise the shared
            placeholder ``"CP_ANCHOR"``. Previously a supplied ``cp_id`` was discarded here, so
            every concentration customer ended up sharing one "CP_ANCHOR" counterparty.

    Returns:
        A ``pyspark.sql.Row`` annotated with scenario_type ``"injected_last_month_capped"`` and
        ``anomaly_flag=True``.
    """
    currency, amount = _pick_currency_for_usd_target(usd_amount)
    return Row(**{
        "txn_id": _rand_id("T", 12),
        "txn_ts": ts,
        "customer_id": c["customer_id"],
        "account_id": c["account_id"],
        # KYC snapshot
        "customer_type": c.get("customer_type"),
        "customer_name": c.get("customer_name"),
        "date_of_birth": c.get("date_of_birth"),
        "address_full": c.get("address_full"),
        "occupation": c.get("occupation"),
        "pep_indicator": bool(c.get("pep_indicator", False)),
        "industry_code": c.get("industry_code"),
        "account_open_months": None,
        "whitelist_flag": bool(c.get("whitelist_flag", False)),
        # Txn
        "direction": direction,
        "method": method,
        "amount": float(amount),
        "currency": currency,
        "amount_usd": float(_usd(amount, currency)),
        "purpose_code": purpose,
        "cross_border_flag": bool(cross_border),
        "declared_counterparty_name": None,
        "balance_after_txn": 0.0,
        "is_split_component": False,
        # Counterparty. Computed inline (after txn_id) to keep the random-draw order unchanged.
        "cp_id": (cp_id or "CP_ANCHOR") if same_cp else (cp_id or _rand_id("CP", 8)),
        "cp_type": "external_legal_entity",
        "cp_jurisdiction": cp_jur,
        "cp_known_shell_flag": bool(known_shell),
        "cp_is_bank_customer": False,
        "counterparty_name": None,
        "counterparty_address": None,
        "counterparty_ip": None,
        # Endpoint
        "endpoint_type": "online_banking",
        "ip_hash": f"IP_{_rand_id('',6)[1:]}",
        "device_id_hash": f"DEV_{_rand_id('',6)[1:]}",
        "geo_lat": None, "geo_lon": None, "atm_id": None,
        "withdrawal_channel": None, "pos_terminal_id": None,
        # Linkage/annotations
        "rep_id_hash": c.get("rep_id_hash"),
        "address_hash": c.get("address_hash"),
        "website_hash": c.get("website_hash"),
        "scenario_type": "injected_last_month_capped",
        "scenario_role": "injector",
        "anomaly_flag": True,
    })

def inject_last_month_anomalies_proportional_capped(
    spark: SparkSession,
    *,
    train_df: DataFrame,       # up to 2025-05-31
    pred_base_df: DataFrame,   # June 2025 baseline
    frac_customers: Dict[str, float],
    multipliers: Dict[str, float],
    max_injected_share_of_june: float = 0.10,  # ≤ 10% extra rows vs June baseline
    seed: int = 1337
) -> DataFrame:
    """Append synthetic anomalous June-2025 transactions to the June baseline.

    Customers present in both ``train_df`` and ``pred_base_df`` are candidates. For each pattern
    tag, a fraction ``frac_customers[tag]`` of them is selected: at least one customer when the
    fraction is positive, none when it is 0 or missing. Selected customers receive extra rows that
    push a June statistic past a target. The target comes from their Mar–May 2025 monthly
    baseline scaled by ``multipliers``, or from a fixed floor. The pattern tags are
    ``high_monthly_value``, ``high_monthly_value_hr``, ``high_monthly_volume``,
    ``high_single_txn``, ``high_value_3d``, ``many_counterparties``, ``concentration_one_cp``
    and ``fast_in_fast_out``.

    Args:
        spark: session used to build the injected DataFrame.
        train_df: history up to 2025-05-31; its Mar–May 2025 rows form the baseline.
        pred_base_df: June 2025 baseline; its schema is reused for the injected rows.
        frac_customers: pattern tag -> fraction of active customers to inject.
        multipliers: pattern knobs (``turnover``, ``turnover_hr``, ``volume``, ``single_txn``,
            ``value_3d``, ``distinct_cp``, ``concentration_share``); missing keys use defaults.
        max_injected_share_of_june: cap on injected rows, as a fraction of June baseline rows.
        seed: seeds the global ``random`` and ``numpy.random`` state.

    Returns:
        ``pred_base_df`` unioned by name with the injected rows. If there are no active customers,
        or nothing survives the cap, ``pred_base_df`` is returned unchanged.
    """
    random.seed(seed); np.random.seed(seed)
    base_months = 3  # Mar/Apr/May baseline length; turns 3-month totals into monthly averages

    # Bounds for June
    last_start = datetime(2025,6,1,0,0,0)
    last_end   = datetime(2025,6,30,23,59,59)
    # Sorted so random.choice below is reproducible under a fixed seed: the iteration order of a
    # set of str depends on the per-process hash seed.
    high_risk_dest = sorted(HIGH_RISK_DEST)

    # Ensure continuity: consider customers that exist in TRAIN and in June baseline
    train_customers = train_df.select("customer_id","account_id","customer_type","customer_name",
                                      "date_of_birth","address_full","occupation","pep_indicator",
                                      "industry_code","whitelist_flag","rep_id_hash","address_hash","website_hash") \
                              .dropDuplicates(["customer_id","account_id"])
    june_customers  = pred_base_df.select("customer_id","account_id").dropDuplicates(["customer_id","account_id"])
    active = june_customers.join(train_customers, ["customer_id","account_id"], "inner")
    n_active = active.count()
    if n_active == 0:
        return pred_base_df  # nothing to inject

    # Baselines from previous 3 months (Mar/Apr/May 2025)
    prev3 = train_df.where((F.col("txn_ts") >= F.lit("2025-03-01")) & (F.col("txn_ts") < F.lit("2025-06-01")))
    prev3_agg = prev3.groupBy("customer_id").agg(
        F.sum(F.abs("amount_usd")).alias("base_turnover_usd"),
        F.count("*").alias("base_txn_count"),
        F.max(F.abs("amount_usd")).alias("base_max_single_usd"),
        F.countDistinct("cp_id").alias("base_distinct_cp"),
        F.sum(F.when(F.col("cp_jurisdiction").isin(high_risk_dest),
                     F.abs(F.col("amount_usd"))).otherwise(0.0)).alias("base_turnover_hr_usd")
    )

    june_agg = pred_base_df.groupBy("customer_id").agg(
        F.sum(F.abs("amount_usd")).alias("cur_turnover_usd"),
        F.count("*").alias("cur_txn_count"),
        F.max(F.abs("amount_usd")).alias("cur_max_single_usd"),
        F.countDistinct("cp_id").alias("cur_distinct_cp"),
        F.sum(F.when(F.col("cp_jurisdiction").isin(high_risk_dest),
                     F.abs(F.col("amount_usd"))).otherwise(0.0)).alias("cur_turnover_hr_usd")
    )

    stats = (active.join(june_agg, "customer_id", "left")
                   .join(prev3_agg, "customer_id", "left")
                   .fillna({"cur_turnover_usd":0.0,"cur_txn_count":0,"cur_max_single_usd":0.0,"cur_distinct_cp":0,
                            "base_turnover_usd":0.0,"base_txn_count":0,"base_max_single_usd":0.0,"base_distinct_cp":0,
                            "cur_turnover_hr_usd":0.0,"base_turnover_hr_usd":0.0}))

    # Collect in a fixed order so the seeded shuffle/sampling selects the same customers on every
    # run; row order coming out of a Spark join is not guaranteed.
    candidates = [r.asDict(recursive=True)
                  for r in stats.orderBy("customer_id", "account_id").collect()]
    random.shuffle(candidates)

    def pick_fraction(tag: str) -> List[Dict[str, Any]]:
        """Sample ``frac_customers[tag]`` of the candidates (≥1 when the fraction is positive)."""
        frac = float(frac_customers.get(tag, 0.0))
        if frac <= 0.0:
            return []  # a 0 or missing fraction disables the pattern
        k = max(1, int(frac * n_active))
        return random.sample(candidates, k=min(k, len(candidates)))

    new_rows: List[Row] = []

    # ---------- Patterns (inject deltas to reach targets) ----------
    # 1) High monthly transactional value
    for c in pick_fraction("high_monthly_value"):
        base = float(c["base_turnover_usd"]) / base_months
        target = max(30_000.0, multipliers.get("turnover",3.0)*base)
        delta = max(0.0, target - float(c["cur_turnover_usd"]))
        if delta <= 0: continue
        n = int(np.random.randint(5,9))
        parts = (np.random.dirichlet(np.ones(n)) * delta).tolist()
        times = _sample_times(last_start, last_end, n)
        for a,t in zip(parts,times):
            new_rows.append(_build_row(c,t,method="wire_out",direction="debit",usd_amount=float(a),purpose="invoice"))

    # 2) High monthly value to high-risk jurisdictions
    for c in pick_fraction("high_monthly_value_hr"):
        base = float(c["base_turnover_hr_usd"]) / base_months
        target = max(20_000.0, multipliers.get("turnover_hr",2.5)*base)
        delta = max(0.0, target - float(c["cur_turnover_hr_usd"]))
        if delta <= 0: continue
        n = int(np.random.randint(4,8))
        parts = (np.random.dirichlet(np.ones(n)) * delta).tolist()
        times = _sample_times(last_start, last_end, n)
        for a,t in zip(parts,times):
            new_rows.append(_build_row(c,t,method="wire_out",direction="debit",usd_amount=float(a),
                                       purpose="services", cp_jur=random.choice(high_risk_dest),
                                       cross_border=True, known_shell=(random.random()<0.3)))

    # 3) High monthly transactional volume
    for c in pick_fraction("high_monthly_volume"):
        base_cnt = float(c["base_txn_count"]) / base_months
        target_cnt = int(max(30, multipliers.get("volume",3.0)*base_cnt))
        need = max(0, target_cnt - int(c["cur_txn_count"]))
        if need <= 0: continue
        times = _sample_times(last_start, last_end, need)
        for t in times:
            amt = float(np.random.lognormal(mean=3.0, sigma=0.4))  # small e-com
            new_rows.append(_build_row(c,t,method="card_ecom",direction="debit",usd_amount=amt,purpose="small_purchase"))

    # 4) High single transactional value
    for c in pick_fraction("high_single_txn"):
        base_max = float(c["base_max_single_usd"])
        target = max(20_000.0, multipliers.get("single_txn",3.0)*base_max)
        if float(c["cur_max_single_usd"]) >= target: continue
        t = _sample_times(last_start, last_end, 1)[0]
        new_rows.append(_build_row(c,t,method="wire_out",direction="debit",usd_amount=target,purpose="one_off_capital"))

    # 5) High transactional value in a 3-day window
    for c in pick_fraction("high_value_3d"):
        base = float(c["base_turnover_usd"]) / base_months
        target_cluster = max(25_000.0, multipliers.get("value_3d",2.5)*base)
        s,e = _pick_3day_window(last_start,last_end)
        n = int(np.random.randint(5,9))
        parts = (np.random.dirichlet(np.ones(n)) * target_cluster).tolist()
        times = _sample_times(s,e,n)
        for a,t in zip(parts,times):
            new_rows.append(_build_row(c,t,method="wire_out",direction="debit",usd_amount=float(a),purpose="clustered_payouts"))

    # 6) High number of counterparties
    for c in pick_fraction("many_counterparties"):
        base_k = float(c["base_distinct_cp"]) / base_months
        target_k = int(max(10, multipliers.get("distinct_cp",3.0)*base_k))
        need_k = max(0, target_k - int(c["cur_distinct_cp"]))
        if need_k <= 0: continue
        times = _sample_times(last_start, last_end, need_k)
        for t in times:
            amt = float(np.random.lognormal(9.5,0.4)/1000)
            new_rows.append(_build_row(c,t,method="wire_out",direction="debit",
                                       usd_amount=amt,purpose="supplier",cp_id=_rand_id("CP",8)))

    # 7) High concentration to a single counterparty
    for c in pick_fraction("concentration_one_cp"):
        target_share = float(multipliers.get("concentration_share",0.85))
        base = float(c["base_turnover_usd"]) / base_months
        anchor = _rand_id("CP",8)
        n = int(np.random.randint(25,45))
        chunk = max(20_000.0, 0.8*target_share*base)
        times = _sample_times(last_start,last_end,n)
        for t in times:
            usd_amt = (chunk / n) * float(np.random.uniform(0.8,1.2))
            new_rows.append(_build_row(c,t,method="wire_out",direction="debit",
                                       usd_amount=usd_amt,purpose="services",
                                       cp_id=anchor, same_cp=True))

    # 8) Similar high credit and debit in a 3-day window (fast in/out)
    for c in pick_fraction("fast_in_fast_out"):
        base = float(c["base_turnover_usd"]) / base_months
        s,e = _pick_3day_window(last_start,last_end)
        # The credit lands in the first 24h of the window and the debits follow 6–47h later, so
        # every leg stays inside [s, e) and hence inside June. Previously the credit could land
        # as late as e - 24h, pushing debits past the window (and into July).
        t_in = _sample_times(s, e - timedelta(hours=48), 1)[0]
        base_in = max(25_000.0, 0.6*base)
        # credit in
        new_rows.append(_build_row(c,t_in,method="wire_in",direction="credit",
                                   usd_amount=base_in,purpose="unexpected_inflow"))
        # one or two debits totaling ≈ credit (±5%)
        k = int(np.random.randint(1,3))
        out_total = base_in * float(np.random.uniform(0.95,1.05))
        weights = np.random.dirichlet(np.ones(k))
        for w in weights:
            t_out = t_in + timedelta(hours=int(np.random.randint(6,48)))
            new_rows.append(_build_row(c,t_out,method="wire_out",direction="debit",
                                       usd_amount=float(out_total*w),purpose="rapid_outflow"))

    # ---------- Hard cap on injected rows ----------
    # NOTE(review): truncating individual rows can split multi-row patterns (e.g. keep a
    # fast_in_fast_out credit but drop its debit); cap by pattern group if that matters.
    june_rows = pred_base_df.count()
    cap = int(june_rows * max_injected_share_of_june)
    if len(new_rows) > cap:
        random.shuffle(new_rows)
        new_rows = new_rows[:cap]

    if not new_rows:
        return pred_base_df

    injected_df = spark.createDataFrame(new_rows, schema=pred_base_df.schema)
    return pred_base_df.unionByName(injected_df)


In [None]:
# Step 1: build train (capped at max_train_rows = 800k) and the June-2025 baseline
train, pred_base = build_train_and_pred_base(spark, n_customers=1500, seed=1337, max_train_rows=800000)

print("TRAIN rows:", train.count())
print("JUNE baseline rows:", pred_base.count())

# Step 2: inject strong anomalies into June, capped
# Fraction of June-active customers that receive each injected pattern.
frac_customers = {
    "high_monthly_value": 0.05,
    "high_monthly_value_hr": 0.05,
    "high_monthly_volume": 0.05,
    "high_single_txn": 0.05,
    "high_value_3d": 0.05,
    "many_counterparties": 0.05,
    "concentration_one_cp": 0.05,
    "fast_in_fast_out": 0.05,
}
# Pattern strength relative to each customer's Mar–May 2025 baseline statistics
# (concentration_share is a target share of turnover, not a multiplier).
multipliers = {
    "turnover": 5.0,
    "turnover_hr": 4.5,
    "volume": 4.0,
    "single_txn": 4.0,
    "value_3d": 4.0,
    "distinct_cp": 4.0,
    "concentration_share": 0.95,
}

pred = inject_last_month_anomalies_proportional_capped(
    spark,
    train_df=train,
    pred_base_df=pred_base,
    frac_customers=frac_customers,
    multipliers=multipliers,
    max_injected_share_of_june=0.25,   # ≤ 25% more rows than June baseline
    seed=1337
)

print("PRED rows (after injection):", pred.count())


In [None]:
# from pyspark.sql import functions as F

# # Add a source label
# train_labeled = train.withColumn("dataset_type", F.lit("train"))
# pred_labeled  = pred.withColumn("dataset_type", F.lit("pred"))

# # Combine into one DataFrame
# df = train_labeled.unionByName(pred_labeled)

# # Optional: cache if you'll query it a lot
# df.cache()

# print("Combined count:", df.count())
# df.groupBy("dataset_type").count().show()

In [None]:
from pyspark.sql import Row, functions as F
from datetime import datetime, timedelta, date
import numpy as np
import random, string

# --- Config for the hand-crafted "student" case ---
TOTAL_IN_UAH = 450_000                               # total of the inbound wires, in UAH
OUT_UAH = 420_000                                    # single outbound wire, in UAH
FX_UAH_TO_USD = 0.027                                # keep equal to the generator's UAH rate (FX_RATES_TO_USD["UAH"])
DOB_2007 = date(2007, 4, 15)                         # a 2007 birth date makes the subject 17–18 in June 2025
THREE_DAY_START = datetime(2025, 6, 10, 10, 0, 0)    # start of the 3-day activity window

def _rand_id(prefix: str, n: int = 10) -> str:
    s = "".join(random.choices(string.ascii_uppercase + string.digits, k=n))
    return f"{prefix}_{s}"

def _sample_times_3d(start_dt: datetime, n: int):
    end_dt = start_dt + timedelta(days=3)
    span = int((end_dt - start_dt).total_seconds())
    offs = sorted(np.random.randint(0, max(1, span+1), size=n).tolist())
    return [start_dt + timedelta(seconds=int(o)) for o in offs]

# KYC/linkage columns copied from an existing customer so the injected rows reuse that identity.
_STUDENT_KYC_COLUMNS = [
    "customer_id", "account_id", "customer_type", "customer_name",
    "date_of_birth", "address_full", "occupation", "pep_indicator",
    "industry_code", "whitelist_flag", "rep_id_hash", "address_hash", "website_hash",
]


def _student_case_row(base, ts, *, direction, method, amount_uah, purpose_code,
                      cross_border_flag, declared_counterparty_name, cp_type,
                      cp_jurisdiction, counterparty_name):
    """Build one UAH transaction Row for the student case, stamped with the KYC snapshot ``base``.

    Random ids are drawn in the order txn_id, cp_id, ip_hash, device_id_hash; each row gets a
    fresh counterparty id.
    """
    return Row(**{
        "txn_id": _rand_id("T", 12),
        "txn_ts": ts,
        "customer_id": base["customer_id"],
        "account_id": base["account_id"],
        # KYC snapshot (per-row in this schema)
        "customer_type": base.get("customer_type"),
        "customer_name": base.get("customer_name"),
        "date_of_birth": base.get("date_of_birth"),
        "address_full": base.get("address_full"),
        "occupation": base.get("occupation"),
        "pep_indicator": bool(base.get("pep_indicator", False)),
        "industry_code": base.get("industry_code"),
        "account_open_months": None,
        "whitelist_flag": bool(base.get("whitelist_flag", False)),
        # Txn
        "direction": direction,
        "method": method,
        "amount": float(amount_uah),
        "currency": "UAH",
        "amount_usd": round(amount_uah * FX_UAH_TO_USD, 2),
        "purpose_code": purpose_code,
        "cross_border_flag": cross_border_flag,
        "declared_counterparty_name": declared_counterparty_name,
        "balance_after_txn": 0.0,
        "is_split_component": False,
        # Counterparty
        "cp_id": _rand_id("CP", 8),
        "cp_type": cp_type,
        "cp_jurisdiction": cp_jurisdiction,
        "cp_known_shell_flag": False,
        "cp_is_bank_customer": False,
        "counterparty_name": counterparty_name,
        "counterparty_address": None,
        "counterparty_ip": None,
        # Endpoint
        "endpoint_type": "online_banking",
        "ip_hash": f"IP_{_rand_id('',6)[1:]}",
        "device_id_hash": f"DEV_{_rand_id('',6)[1:]}",
        "geo_lat": None, "geo_lon": None, "atm_id": None,
        "withdrawal_channel": None, "pos_terminal_id": None,
        # Linkage
        "rep_id_hash": base.get("rep_id_hash"),
        "address_hash": base.get("address_hash"),
        "website_hash": base.get("website_hash"),
        # Annotations
        "scenario_type": "case_student_cluster",
        "scenario_role": "subject",
        "anomaly_flag": True,
    })


def inject_student_case(pred_df, spark, target_customer_id=None, seed=2025):
    """Append a hand-crafted "student" case (scenario_type ``case_student_cluster``) to ``pred_df``.

    Picks a customer in this order: ``target_customer_id`` if given; otherwise a June-2025
    individual; otherwise any customer. Its KYC snapshot is overridden to a 2007-born student.
    Then 15 inbound UAH wires totalling TOTAL_IN_UAH are added inside the 3-day window starting
    at THREE_DAY_START. Finally, one outbound OUT_UAH wire to a TR counterparty is added after
    the last inbound wire, kept inside that window.

    Args:
        pred_df: prediction-period transactions; its schema is reused for the injected rows.
        spark: SparkSession used to build the injected DataFrame.
        target_customer_id: customer to use; ``None`` picks one automatically.
        seed: seeds the global ``random`` and ``numpy.random`` state.

    Returns:
        ``(pred_df unioned with the injected rows, chosen customer_id)``.

    Raises:
        ValueError: if ``target_customer_id`` is not in ``pred_df``, or ``pred_df`` is empty.
    """
    random.seed(seed); np.random.seed(seed)

    schema = pred_df.schema

    # 1) Find a June-active individual customer if not provided
    if target_customer_id is None:
        june_inds = (
            pred_df
            .where((F.col("txn_ts") >= F.lit("2025-06-01")) & (F.col("txn_ts") < F.lit("2025-07-01")))
            .where(F.col("customer_type") == "individual")
            .select(*_STUDENT_KYC_COLUMNS)
            .dropDuplicates(["customer_id","account_id"])
            .limit(1)
            .collect()
        )
        if not june_inds:
            # fallback: any customer
            june_inds = (
                pred_df.select(*_STUDENT_KYC_COLUMNS)
                       .dropDuplicates(["customer_id","account_id"])
                       .limit(1)
                       .collect()
            )
        if not june_inds:
            raise ValueError("pred_df is empty; no customer available for the student case")
        base = june_inds[0].asDict(recursive=True)
    else:
        # Pull a snapshot row for this customer from pred
        matches = (
            pred_df.where(F.col("customer_id")==F.lit(target_customer_id))
                   .select(*_STUDENT_KYC_COLUMNS)
                   .limit(1).collect()
        )
        if not matches:
            raise ValueError(f"customer_id {target_customer_id} not found in pred_df")
        base = matches[0].asDict(recursive=True)

    # Force KYC snapshot for this case
    base["customer_type"] = "individual"
    base["occupation"] = "student"
    # DOB_2007 is a datetime.date. This assumes date_of_birth is DateType in pred_df's schema
    # (TODO confirm): createDataFrame (default verifySchema=True) rejects a date for a StringType column.
    base["date_of_birth"] = DOB_2007

    # 2) 15 inbound wires totalling ~TOTAL_IN_UAH inside the 3-day window
    n_in = 15
    window_end = THREE_DAY_START + timedelta(days=3)
    times_in = _sample_times_3d(THREE_DAY_START, n_in)
    # Split the total into n_in positive parts
    parts = np.random.dirichlet(np.ones(n_in)) * TOTAL_IN_UAH
    amounts_in_uah = [float(round(a, 2)) for a in parts]

    rows = [
        _student_case_row(base, ts, direction="credit", method="wire_in", amount_uah=amt_uah,
                          purpose_code="tuition_support", cross_border_flag=False,
                          declared_counterparty_name=None, cp_type="external_individual",
                          cp_jurisdiction="UA", counterparty_name=None)
        for ts, amt_uah in zip(times_in, amounts_in_uah)
    ]

    # 3) One outbound wire to Turkey, 4–23h after the last inbound. It is clamped to the window
    #    end because the last inbound can fall late enough that the offset would leave the window.
    t_out = min(times_in[-1] + timedelta(hours=int(np.random.randint(4, 24))), window_end)
    rows.append(_student_case_row(base, t_out, direction="debit", method="wire_out",
                                  amount_uah=OUT_UAH, purpose_code="study_expense",
                                  cross_border_flag=True,
                                  declared_counterparty_name="Overseas Tuition",
                                  cp_type="foreign_bank", cp_jurisdiction="TR",
                                  counterparty_name="TR_Beneficiary"))

    injected = spark.createDataFrame(rows, schema=schema)

    return pred_df.unionByName(injected), base["customer_id"]

# --------- Run it ---------
pred_with_case, student_customer_id = inject_student_case(pred, spark)
print("Injected student case for customer:", student_customer_id)
print(pred_with_case.count())

# Quick verification: list the injected student-case rows in time order.
from pyspark.sql import functions as F
is_student_case = (
    (F.col("customer_id") == student_customer_id)
    & (F.col("scenario_type") == "case_student_cluster")
)
verification_cols = [
    "txn_ts", "direction", "method", "currency", "amount", "amount_usd",
    "cp_jurisdiction", "purpose_code",
]
(
    pred_with_case
    .where(is_student_case)
    .select(*verification_cols)
    .orderBy("txn_ts")
    .show(50, truncate=False)
)

In [None]:
from pyspark.sql import functions as F

# Label every row with the split it came from so the combined frame can be separated later.
train_labeled = train.withColumn("dataset_type", F.lit("train"))
pred_labeled = pred_with_case.withColumn("dataset_type", F.lit("pred"))

# One frame holding both splits. cache() marks it for caching and returns the same DataFrame;
# it is counted and grouped here and enriched further below.
df = train_labeled.unionByName(pred_labeled).cache()

print("Combined count:", df.count())
df.groupBy("dataset_type").count().show()

In [None]:
from pyspark.sql import functions as F

# Counterparty-jurisdiction -> risk-level lookup.
# NOTE(review): AE is "Medium" here but is in HIGH_RISK_DEST used by the anomaly injector — confirm
# which classification is intended.
COUNTRY_RISK = {
    "UA": "Medium",
    "PL": "Low",
    "TR": "High",
    "CY": "High",
    "AE": "Medium",
    "GB": "Low",
    "DE": "Low",
    "US": "Low",
    "LT": "Low"
}

# Literal MAP<string,string> built from COUNTRY_RISK. Jurisdictions missing from the map (or null)
# yield null under default (non-ANSI) settings. Indexing with `[...]` replaces getItem(Column),
# which PySpark deprecated for Column keys as of 3.0.
_country_risk_map = F.create_map(*[F.lit(x) for kv in COUNTRY_RISK.items() for x in kv])
df = df.withColumn("cp_country_risk_level", _country_risk_map[F.col("cp_jurisdiction")])


In [None]:
# True only when account_open_months is known and <= 6; a null tenure counts as not new.
df = df.withColumn(
    "is_new_account",
    f.coalesce(f.col('account_open_months') <= 6, f.lit(False)),
)

In [None]:
from common.libs import dates as dates_lib

trx_date_column_name = "txn_ts"
# Day/month offset columns derived from the transaction timestamp; year_month is then derived
# from month_offset.
for _add_offset_column, _offset_column in (
    (dates_lib.add_day_offset_column, 'day_offset'),
    (dates_lib.add_month_offset_column, 'month_offset'),
):
    df = _add_offset_column(df, trx_date_column_name, _offset_column)
df = dates_lib.month_offset_to_year_month_columns(df, 'month_offset', 'year_month')

In [None]:
from domains.common.libs.tr_levenshtein import get_lev_ind
# Me-to-me: compare the customer's name with the counterparty's name via get_lev_ind.
# NOTE(review): threshold is presumably the maximum Levenshtein distance counted as a match — confirm
# in tr_levenshtein.
threshold = 1
is_me_to_me_col = get_lev_ind('customer_name', 'counterparty_name', threshold)
df = df.withColumn("is_me_to_me", is_me_to_me_col)
print("Me-to-me field added")

In [None]:
# Materialise the enriched frame once and display its row count.
total_rows = df.count()
total_rows

In [None]:
# from thetaray.utils.schema_ds_ffc_generator import create_metadata_ds_file_from_df, create_metadata_ds_file_from_csv
# from thetaray.api.solution import DataSet, Field, DataType, IngestionMode

# create_metadata_ds_file_from_df(context=context,
#                                 df=df,
#                                 ds_identifier="trx_enriched",
#                                 ds_display_name="trx_enriched",
#                                 ingestion_mode=IngestionMode.APPEND,
#                                 publish=True,
#                                 primary_key=['txn_id'],
#                                 occurred_on_field="txn_ts",
#                                 data_permission="dpv:demo_fuib",
#                                 num_of_partitions=1,
#                                 num_of_buckets=1)

In [None]:
from thetaray.common.data_environment import DataEnvironment
# Persist the enriched transactions as dataset "trx_enriched" in the public data environment.
dataset_functions.write(
    context,
    df,
    "trx_enriched",
    data_environment=DataEnvironment.PUBLIC,
)