In [None]:
from datetime import datetime

from pyspark.sql import Window, functions as f

from lib.entity_resolution import EntityResolution, MatchState
from thetaray.api.context import init_context
from thetaray.api.dataset import dataset_functions
from thetaray.api.solution.explainability_validator import validate_explainabilities

# Fixed execution date (the Unix epoch) for the ThetaRay context; it is also used as the
# `to_job_ts` upper bound when the transaction dataset is read.
EXECUTION_DATE = datetime(1970, 1, 1)
context = init_context(execution_date=EXECUTION_DATE)

In [None]:
# Read every input dataset from the current context.
def _read(dataset_name, **read_kwargs):
    """Read `dataset_name` through dataset_functions.read with the notebook's context."""
    return dataset_functions.read(context, dataset_name, **read_kwargs)


# The account's name/address/country columns are not features; they are re-read later only
# where the explainability widgets need them.
ds_account = _read("account").drop("name", "address", "country")
ds_card = _read("card")
ds_client = _read("client")
ds_disp_owner = _read("disp_owner")
ds_disp_disponent = _read("disp_disponent")
ds_district = _read("district")
ds_loan = _read("loan")
# Transactions are bounded by the context's execution date.
ds_transaction = _read("transaction", to_job_ts=context.execution_date)
ds_order = _read("order")
ds_country_risk = _read("country_risk")

In [None]:
# Dispositions of both kinds (owners and disponents) in one frame.
# NOTE(review): union() matches columns by position, not name — assumes both datasets share
# the same column order; unionByName would be safer if that is not guaranteed.
ds_disp = ds_disp_owner.union(ds_disp_disponent)

# Build the base feature frame: account -> disposition -> loan -> client -> card.
# Same-named columns are renamed before each join so every source keeps its own copy.
features = (
    ds_account
    .join(ds_disp, on="account_id", how="outer")
    .withColumnRenamed("date", "date_acct")
    .join(ds_loan.withColumnRenamed("date", "date_loan"), on="account_id", how="left")
    .withColumnRenamed("district_id", "district_id_bank")
    .join(ds_client.withColumnRenamed("district_id", "district_id_client"), on="client_id", how="outer")
    .withColumnRenamed("type", "type_disp")
    .join(ds_card.withColumnRenamed("type", "type_card"), on="disp_id", how="outer")
    .drop("codes")
    .withColumnRenamed("date", "date_card")
)

In [None]:
# Optional sanity check, left disabled because count() triggers a full Spark job:
# print(features.count(), "total feature records, ie one for each client")  # should be 5369

In [None]:
# Restrict the feature frame to rows that have a loan (the loan status is the modelling target).
features = features.where(f.col("loan_id").isNotNull())
# Optional sanity checks, disabled because each one triggers a Spark job:
# print(features.count(), "feature records with a loan; some accts repeated due to multiple clients on same acct")  # should be 827
# print(features.select("account_id").distinct().count(), "feature records with a loan and unique account_id")  # should be 682

In [None]:
# Balance statistics over look-back windows ending at each account's loan date.
# For x in 1..N_MONTHS, min{x}/max{x}/mean{x} summarise `balance` over the transactions made
# strictly before the loan date and fewer than x * DAYS_PER_MONTH days before it; they are null
# when the account has no transaction in that window.
#
# Implemented with conditional aggregation (one group-by) instead of per-window window
# functions. The old version evaluated the windows after grouping by `balance`, so `mean{x}`
# was the mean of *distinct* balance values; it is now the mean over transactions.
# min{x}/max{x} are unchanged.
N_MONTHS = 6
DAYS_PER_MONTH = 30

# `features` repeats an account once per client on it; de-duplicate so each transaction
# is counted exactly once.
loan_dates = features.select("account_id", "date_loan").distinct()

trans_acctdate = (
    ds_transaction
    .join(loan_dates, on="account_id")
    .withColumn("datediff", f.datediff(f.col("date_loan"), f.col("date")))
    # Keep transactions before the loan date that fall inside the widest window.
    .filter((f.col("datediff") > 0) & (f.col("datediff") < N_MONTHS * DAYS_PER_MONTH))
)

aggs = []
for x in range(1, N_MONTHS + 1):
    # Null outside the window, so min/max/mean ignore those transactions.
    windowed_balance = f.when(f.col("datediff") < x * DAYS_PER_MONTH, f.col("balance"))
    aggs.extend(
        [
            f.min(windowed_balance).alias(f"min{x}"),
            f.max(windowed_balance).alias(f"max{x}"),
            f.mean(windowed_balance).alias(f"mean{x}"),
        ],
    )

monbalstats = trans_acctdate.groupBy("account_id").agg(*aggs)

features = features.join(monbalstats, on="account_id", how="left")

In [None]:
# Binary target from the loan `status`: A or C (good) -> 1; anything else, including
# B, D and a missing status, -> 0.
features = features.withColumn(
    "response",
    f.when(f.col("status").isin("A", "C"), 1).otherwise(0),
).drop("status")

# Not every client has a credit card, so the card columns can be NaN/null, which the model
# cannot accept. Replace the card's `issued` column with a 0/1 `has_card` flag
# (`type_card` is left as-is here).
features = features.withColumn(
    "has_card",
    f.col("issued").isNotNull().cast("int"),
).drop("issued")

In [None]:
# Remove the `tr_timestamp*` bookkeeping columns (presumably per-source ingestion timestamps —
# confirm); drop() by name silently ignores any that are not present.
_timestamp_columns = [
    "tr_timestamp",
    "tr_timestamp_client",
    "tr_timestamp_bank",
    "tr_timestamp_y",
    "tr_timestamp_x",
]
features = features.drop(*_timestamp_columns)

In [None]:
# Store the 0/1 indicator columns as long (64-bit integer).
for _binary_col in ("has_card", "response"):
    features = features.withColumn(_binary_col, f.col(_binary_col).cast("long"))

In [None]:
# EXPLAINABILITY (WIDGETS)

# Seed for the synthetic `population_amount` column. Without a seed, f.rand() draws new
# values on every run, so the written dataset would differ between identical runs.
# NOTE(review): a seeded f.rand() is reproducible only for the same input partitioning.
POPULATION_AMOUNT_SEED = 42

# Synthetic comparison amount: `amount` (presumably the loan amount — confirm) scaled by a
# uniform factor in [0.5, 0.75).
features = features.withColumn(
    'population_amount',
    f.col('amount') * (f.rand(seed=POPULATION_AMOUNT_SEED) * .25 + .5),
)

# Re-read the account dataset for the name/address columns that were dropped from ds_account.
account_name_address = dataset_functions.read(context, 'account').select('account_id', 'name', 'address')
features = features.join(account_name_address, on='account_id', how='left_outer')

In [None]:
# EXPLAINABILITY (WIDGETS) [MVP-51612]

# High Risk Countries
account_country = dataset_functions.read(context, 'account').select('account_id', 'country')

# Attach the receiver's country (inner join: transactions whose receiver_id is not a known
# account are dropped).
# NOTE(review): ds_transaction is reassigned in place and the keyword-matches cell reads the
# reassigned frame; re-running this cell without re-reading the data duplicates columns.
ds_transaction = ds_transaction.join(
    account_country,
    ds_transaction['receiver_id'] == account_country['account_id']
).select(ds_transaction['*'], account_country['country'].alias('receiver_country'))

# Attach the receiver country's risk level (inner join: countries missing from
# country_risk are dropped).
ds_transaction = ds_transaction.join(
    ds_country_risk,
    ds_transaction['receiver_country'] == ds_country_risk['country_code']
).select(ds_transaction['*'], ds_country_risk['risk'].alias('receiver_country_risk'))

# Per account: a list of {ct: country, cr: risk, c: count, s: sum of amount} entries, plus a
# flag telling whether any receiver country is rated 'High' or 'Very High'.
high_risk_countries_explainability = (
    ds_transaction
    .groupby('account_id', 'receiver_country', 'receiver_country_risk')
    .agg(f.count('*').alias('count'), f.sum('amount').alias('sum'))
    .groupby('account_id')
    .agg(f.collect_list(
        f.struct(
            f.col('receiver_country').alias('ct'),
            f.col('receiver_country_risk').alias('cr'),
            f.col('count').alias('c'),
            f.col('sum').alias('s')
        )).alias('high_risk_countries_explainability')
    )
    .select(
        'account_id',
        f.arrays_overlap(
            f.transform(f.col('high_risk_countries_explainability'), lambda x: x['cr']),
            f.array(f.lit('High'), f.lit('Very High'))
        ).alias('high_risk_countries'),  # high_risk_countries is True if the account sent money to high risk countries
        f.to_json(
            f.create_map(
                f.lit('data'),
                f.col('high_risk_countries_explainability')
            )
        ).alias('high_risk_countries_explainability')
    )
)

# Left join (as in the keyword-matches cell) so accounts without any qualifying transaction
# stay in the feature set instead of being silently dropped by an inner join.
features = features.join(
    high_risk_countries_explainability.alias('EXPL'),
    on='account_id',
    how='left'
).select(features['*'], 'EXPL.high_risk_countries', 'EXPL.high_risk_countries_explainability')

# Accounts absent from the explainability frame, and null arrays_overlap results (returned when
# there is no overlap but a risk value is null), both become False.
features = features.withColumn('high_risk_countries', f.coalesce(f.col('high_risk_countries'), f.lit(False)))

In [None]:
# EXPLAINABILITY (WIDGETS) [MVP-51612]

# Keyword Matches
# Only transactions with a keyword group count as matches. f.isnan() alone is not enough:
# it evaluates to False for null, so null keyword groups would pass the filter and every
# account with a transaction would be flagged.
keyword_matches_explainability = (
    ds_transaction
    .filter(f.col('keyword_group').isNotNull() & ~f.isnan('keyword_group'))
    .groupby('account_id', 'keyword_group')
    .agg(f.count('*').alias('count'), f.sum('amount').alias('sum'))
    .groupby('account_id')
    .agg(f.collect_list(
        f.struct(
            f.col('keyword_group').alias('kw'),
            f.col('count').alias('c'),
            f.col('sum').alias('s')
        )).alias('keyword_matches_explainability')
    )
    .select(
        'account_id',
        f.lit(True).alias('keyword_matches'),  # every account here has at least one matching transaction
        f.to_json(
            f.create_map(
                f.lit('data'),
                f.col('keyword_matches_explainability')
            )
        ).alias('keyword_matches_explainability')
    )
)

features = features.join(
    keyword_matches_explainability.alias('EXPL'),
    on='account_id',
    how='left'
).select(features['*'], 'EXPL.keyword_matches', 'EXPL.keyword_matches_explainability')

# Accounts without any matching transaction come out of the left join as null -> False.
features = features.withColumn('keyword_matches', f.coalesce(f.col('keyword_matches'), f.lit(False)))

In [None]:
# Reduce the feature frame to a small, fixed number of partitions before validation and write-out.
N_OUTPUT_PARTITIONS = 4
features = features.coalesce(N_OUTPUT_PARTITIONS)

In [None]:
# Validate the explainability columns of `features` for 'tr_analysis'.
# show_all_columns=False returns only the explainability and validation columns;
# True would return the full input frame with the validation columns appended.
validation_df = validate_explainabilities(context, features, 'tr_analysis', show_all_columns=False)

## Load parties from the graph and join them to the aggregated features

### Load graph

In [None]:
# Column names used with the entity-resolution graph: parties (grouping) own accounts (primary
# entities); the per-party account list is written to `party_accounts`.
grouping_identifier, primary_identifier, primary_entities_list = "party_id", "account_id", "party_accounts"

# Load the "public" graph, requesting only auto- or manually-confirmed matches, and keep
# just the (party_id, account_id) pairs.
er = EntityResolution(context=context, graph_id="public")
confirmed_states = [MatchState.AUTO_CONFIRM, MatchState.MANUAL_CONFIRM]
graph_account_party = er.load_graph(
    grouping_identifier=grouping_identifier,
    primary_identifier=primary_identifier,
    states=confirmed_states,
).select(grouping_identifier, primary_identifier)

# Preview a few rows.
graph_account_party.limit(5).toPandas()

### Aggregate each party's accounts into the `primary_entities_list` column

In [None]:
# One {primary_identifier: value} map per (party, account) row.
map_column = f.create_map(f.lit(primary_identifier), f.col(primary_identifier))

# Per party: a JSON string holding the list of account maps.
mapped_df = (
    graph_account_party
    .withColumn(primary_entities_list, map_column)
    .groupBy(grouping_identifier)
    .agg(f.to_json(f.collect_list(primary_entities_list)).alias(primary_entities_list))
)

# Attach each party's account list to all of its (party, account) rows (broadcast hint on the
# aggregated frame), then add a space after every colon so the JSON matches the format
# expected by trace queries.
# NOTE(review): the replacement hits every ':' in the string, including any inside values.
graph_account_party = (
    graph_account_party
    .join(f.broadcast(mapped_df), grouping_identifier)
    .withColumn(primary_entities_list, f.regexp_replace(primary_entities_list, ":", ": "))
)

graph_account_party.limit(5).toPandas()

In [None]:
# Attach party_id / party_accounts to every feature row by account. Joining on the column
# *name* already yields a single account_id column, so the former
# `.drop(graph_account_party[primary_identifier])` was a no-op and has been removed.
# Accounts without a confirmed party keep null party columns (left join).
# NOTE(review): an account linked to several parties would duplicate its feature rows.
features = features.join(
    f.broadcast(graph_account_party),
    on=primary_identifier,
    how="left",
)

In [None]:
# Persist the final feature frame as the "wrangling" dataset.
OUTPUT_DATASET = "wrangling"
dataset_functions.write(context, features, OUTPUT_DATASET)

In [None]:
# Close the context created with init_context() at the start of the notebook.
context.close()