In [None]:
from datetime import datetime

from pyspark.sql import Window, functions as f

from lib.entity_resolution import EntityResolution, MatchState
from thetaray.api.context import init_context
from thetaray.api.dataset import dataset_functions

# Create the platform context used by every read/write below, with the
# execution date fixed to 1970-01-01.
# NOTE(review): the epoch date looks like a development placeholder - confirm
# it is overridden (or intended) when this notebook runs as a scheduled job.
context = init_context(execution_date=datetime(1970, 1, 1))

In [None]:
# Load every source dataset consumed by the party wrangling.

# The account's own name / address / country columns are not carried forward.
ds_account = dataset_functions.read(context, "account")
ds_account = ds_account.drop("name", "address", "country")

ds_card = dataset_functions.read(context, "card")
ds_client = dataset_functions.read(context, "client")

# Dispositions arrive as two datasets (owners and disponents).
ds_disp_owner = dataset_functions.read(context, "disp_owner")
ds_disp_disponent = dataset_functions.read(context, "disp_disponent")

# NOTE(review): ds_district is not referenced by any later cell visible here.
ds_district = dataset_functions.read(context, "district")
ds_loan = dataset_functions.read(context, "loan")

# Transactions are the only dataset read with an explicit `to_job_ts` bound.
ds_transaction = dataset_functions.read(
    context,
    "transaction",
    to_job_ts=context.execution_date,
)
ds_order = dataset_functions.read(context, "order")

### Load graph

In [None]:
# Load the party <-> account links from the "public" entity-resolution graph,
# restricted to matches in the AUTO_CONFIRM or MANUAL_CONFIRM state, and keep
# only the two identifier columns.
er = EntityResolution(context=context, graph_id="public")
graph_account_party = er.load_graph(
    grouping_identifier="party_id",
    primary_identifier="account_id",
    states=[MatchState.AUTO_CONFIRM, MatchState.MANUAL_CONFIRM],
).select("party_id", "account_id")
# graph_account_party.limit(5).toPandas()

### Aggregate all accounts of each party into the party_accounts column

In [None]:

# For every party, collect one {"account_id": <id>} map per linked account and
# serialise the collected list as a JSON string in `party_accounts`.
mapped_df = (
    graph_account_party.withColumn(
        "party_accounts",
        f.create_map(f.lit("account_id"), "account_id"),
    )
    .groupBy("party_id")
    .agg(f.to_json(f.collect_list("party_accounts")).alias("party_accounts"))
)

# Attach the per-party JSON back onto every (party_id, account_id) row; the
# aggregated frame is sent to the join with a broadcast hint.
graph_account_party = graph_account_party.join(f.broadcast(mapped_df), "party_id")

# Put a space after each colon in order to be compatible with trace queries.
# NOTE(review): this rewrites every ":" in the string, so an account_id value
# containing a colon would be altered too - confirm ids never contain one.
graph_account_party = graph_account_party.withColumn(
    "party_accounts",
    f.regexp_replace("party_accounts", ":", ": "),
)

graph_account_party.limit(5).toPandas()

### Inner join parties to accounts, so from now on we work only with accounts that have parties

In [None]:
# Keep only the party links whose account exists in ds_account, pulling in the
# account columns at the same time.
graph_account_party = graph_account_party.join(ds_account, on="account_id", how="inner")
# NOTE(review): joining on the column name already yields a single account_id
# column, so this drop may be a no-op - confirm before removing it.
graph_account_party = graph_account_party.drop(ds_account["account_id"])
# graph_account_party.limit(5).toPandas()

In [None]:
# graph_account_party.count() # 450

In [None]:
# Stack owner and disponent dispositions into one frame.
# NOTE(review): union() matches columns by position - presumably both
# datasets share the same column order; verify.
ds_disp = ds_disp_owner.union(ds_disp_disponent)

# Disambiguate column names that exist on more than one side of the joins.
ds_loan = ds_loan.withColumnRenamed("date", "date_loan")
ds_client = ds_client.withColumnRenamed("district_id", "district_id_client")
ds_card = ds_card.withColumnRenamed("type", "type_card")

# Build the feature frame: party accounts -> dispositions -> loans -> clients
# -> cards. Each rename is applied right before the join that would otherwise
# introduce a second column with the same name.
features = (
    graph_account_party.join(ds_disp, "account_id", "outer")
    .withColumnRenamed("date", "date_acct")
    .join(ds_loan, "account_id", "left")
    .withColumnRenamed("district_id", "district_id_bank")
    .join(ds_client, "client_id", "outer")
    .withColumnRenamed("type", "type_disp")
    .join(ds_card, "disp_id", "outer")
    .drop("codes")
)

# The outer joins keep unmatched rows; retain only rows that have both a party
# and a loan.
features = features.where(features.party_id.isNotNull() & features.loan_id.isNotNull())

In [None]:
# print(features.count(), "feature records with a loan; some accts repeated due to multiple clients on same acct")  # should be 79
# print(features.select("party_id").distinct().count(), "feature records with a loan and unique party_id")  # should be 61

In [None]:
# Attach party_id and loan date to every transaction of a feature account,
# compute how many days before the loan the transaction took place, and keep
# only transactions dated strictly before the loan date (datediff > 0).
trans_acctdate = (
    ds_transaction.join(
        features.select("account_id", "date_loan", "party_id"),
        on="account_id",
        how="inner",
    )
    .withColumn("datediff", f.datediff(f.col("date_loan"), f.col("date")))
    .filter(f.col("datediff") > 0)
)

In [None]:
# trans_acctdate.count() # 5734

In [None]:
# Pre-loan balance statistics per party.
#
# For each look-back horizon x = 1..N_MONTHS, a transaction is flagged M{x}
# when it happened fewer than x * DAYS_PER_MONTH days before the loan date.
# min / max / mean of `balance` are then computed per (party_id, M{x}) and the
# values belonging to the flagged side are attached to `features` as the
# columns min{x}, max{x} and mean{x}.
N_MONTHS = 6  # number of look-back horizons
DAYS_PER_MONTH = 30  # a "month" is approximated as 30 days
STAT_FUNCS = (("min", f.min), ("max", f.max), ("mean", f.mean))

months = range(1, N_MONTHS + 1)
flag_cols = [f"M{x}" for x in months]

window_aggs = []
for x in months:
    trans_acctdate = trans_acctdate.withColumn(
        f"M{x}",
        f.when(f.col("datediff") < x * DAYS_PER_MONTH, True).otherwise(False),
    )
    # Unbounded frame: every row sees its whole (party_id, M{x}) partition.
    w = (
        Window()
        .partitionBy("party_id", f"M{x}")
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    )
    window_aggs.extend(
        stat_func("balance").over(w).alias(f"{stat_name}{x}")
        for stat_name, stat_func in STAT_FUNCS
    )

# `balance` and the flags must be grouping columns so that the window
# expressions may reference them.
# NOTE(review): because rows are grouped by balance before the windows run,
# identical balances with identical flags appear to contribute only once to
# mean{x} - confirm this de-duplication is intended.
monbalstats = trans_acctdate.groupBy("party_id", "balance", *flag_cols).agg(*window_aggs)

# Keep rows that fall inside at least one horizon.
in_any_horizon = f.col(flag_cols[0])
for flag_col in flag_cols[1:]:
    in_any_horizon = in_any_horizon | f.col(flag_col)
monbalstats = monbalstats.filter(in_any_horizon).drop("balance").dropDuplicates()

party_aggs = []
for x in months:
    for stat_name, _ in STAT_FUNCS:
        stat_col = f"{stat_name}{x}"
        # Blank out statistics that belong to the "outside horizon" partition;
        # `when` without `otherwise` yields null for unmatched rows.
        monbalstats = monbalstats.withColumn(stat_col, f.when(f.col(f"M{x}"), f.col(stat_col)))
        # After masking, only in-horizon values remain non-null for a party,
        # so max() collapses the rows to one value per party.
        party_aggs.append(f.max(stat_col).alias(stat_col))

monbalstats = monbalstats.groupBy("party_id").agg(*party_aggs)

features = features.join(monbalstats, on="party_id", how="left").drop("account_id")

In [None]:
# features.select('party_id').count() # 79

In [None]:
# Convert the response variable `status` = {A,B,C,D} into `response` = {0,1}:
# A and C (good) become 1, every other value (B, D - bad) becomes 0.
is_good_status = f.col("status").isin("A", "C")
features = features.withColumn("response", f.when(is_good_status, 1).otherwise(0))
features = features.drop("status")

# Credit card features exist, but not every client has a card, so they can be
# NaN, which isn't acceptable in the modeling. Create a `has_card` = {0,1}
# flag from the card issue date and drop the date itself; `type_card` is kept
# and is meant to be used later in a way that avoids NaNs.
card_was_issued = f.col("issued").isNotNull()
features = features.withColumn("has_card", f.when(card_was_issued, 1).otherwise(0))
features = features.drop("issued")

In [None]:
# Remove every tr_timestamp variant (plain and _client / _bank / _y / _x
# suffixed); names that are absent from the frame are simply ignored by drop().
features = features.drop(
    "tr_timestamp",
    "tr_timestamp_client",
    "tr_timestamp_bank",
    "tr_timestamp_y",
    "tr_timestamp_x",
)

In [None]:
# Store the two 0/1 flags as long, then reduce the frame to at most 4
# partitions before it is written.
features = (
    features
    .withColumn("has_card", f.col("has_card").cast("long"))
    .withColumn("response", f.col("response").cast("long"))
    .coalesce(4)
)

### Write party wrangling ds

In [None]:
# Persist the assembled feature frame as the "party_wrangling" dataset.
dataset_functions.write(context, features, "party_wrangling")

In [None]:
# Close the context created by init_context() at the top of the notebook.
context.close()