In [None]:
import json
import networkx as nx
from pyspark.sql import functions as sf
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import Window

In [None]:
# Read the shared project configuration; its "dataset" entry selects which
# dataset variant the paths below point to.
with open("../config.json") as f:
    config = json.loads(f.read())

DATASET = config["dataset"]
# Parquet input (preprocessed transactions) and output (engineered features).
PREPROCESSED_DATA = "../data/01-ibm-transactions-for-aml/preprocessed/{}-transactions".format(DATASET)
WRITE_LOCATION = "../data/01-ibm-transactions-for-aml/feature_engineering/{}-features".format(DATASET)

In [None]:
# Spark session settings. They live under their own name so that `config`
# keeps referring to the dictionary loaded from config.json instead of being
# rebound to this list.
spark_settings = [
    ("spark.jars.packages", "graphframes:graphframes:0.8.3-spark3.5-s_2.13"),
    ("spark.driver.memory", "8g"),
    # NOTE(review): "spark.worker.memory" does not appear to be a documented
    # Spark property -- confirm whether "spark.executor.memory" was intended.
    ("spark.worker.memory", "8g"),
]
spark = (
    SparkSession.builder
    .appName("feature_engineering")
    .config(conf=SparkConf().setAll(spark_settings))
    .getOrCreate()
)

In [None]:
def rename_columns(dataframe, names):
    """Return `dataframe` with its columns renamed according to `names`.

    `names` maps each existing column name to its replacement. The renames
    are applied one after another, in the mapping's iteration order, through
    `withColumnRenamed`. An empty mapping returns `dataframe` itself.
    """
    renamed = dataframe
    for old_name in names:
        renamed = renamed.withColumnRenamed(old_name, names[old_name])
    return renamed

In [None]:
# Load the preprocessed transactions (parquet) that every feature below is derived from.
data = spark.read.format("parquet").load(PREPROCESSED_DATA)

In [None]:
# Running account balances.
# Build a ledger with two rows per transaction: the source account is debited
# (-amount) and the target account is credited (+amount).
balance_data = data.select(sf.col("source").alias("account"), 
                        (-sf.col("amount")).alias("balance"),
                        sf.col("timestamp"), sf.col("transaction_id")
                        ).union(
                            data.select(sf.col("target").alias("account"),
                                        sf.col("amount").alias("balance"),
                                        sf.col("timestamp"), sf.col("transaction_id")))

balance_data = balance_data.repartition("account")

# Cumulative sum of the signed amounts per account, ordered by timestamp.
# This is a RANGE frame, so ledger rows of one account that share the same
# timestamp all receive the same cumulative value.
windowval = (Window.partitionBy("account").orderBy("timestamp").rangeBetween(Window.unboundedPreceding, 0))
balance_data = balance_data.withColumn("cum_sum", sf.sum("balance").over(windowval))

# Debit rows carry the source-side increment and running balance.
df_negative = balance_data.filter(sf.col("balance") < 0) \
    .withColumnRenamed("account", "source") \
    .withColumnRenamed("balance", "src_increment") \
    .withColumnRenamed("cum_sum", "new_src_balance")

# Credit rows carry the target-side increment and running balance.
# Ledger rows with a balance of exactly 0 match neither filter.
df_positive = balance_data.filter(sf.col("balance") > 0) \
    .withColumnRenamed("account", "target") \
    .withColumnRenamed("balance", "dst_increment") \
    .withColumnRenamed("cum_sum", "new_dst_balance")

# Attach both sides back onto the transactions (left joins keep every transaction).
data = data.repartition("transaction_id", "timestamp").join(df_negative, on=["transaction_id", "timestamp", "source"], how="left") \
        .join(df_positive, on=["transaction_id", "timestamp", "target"], how="left")

# Round the monetary columns to two decimals.
data = data.withColumn("src_increment", sf.round(sf.col("src_increment"), 2)) \
        .withColumn("dst_increment", sf.round(sf.col("dst_increment"), 2)) \
        .withColumn("new_src_balance", sf.round(sf.col("new_src_balance"), 2)) \
        .withColumn("new_dst_balance", sf.round(sf.col("new_dst_balance"), 2))

# Self-transfers (source == target) and pre-transaction balances.
# The order of these steps matters: the increments are zeroed for
# self-transfers first, the old_* balances are then derived as new - increment,
# and finally the source-side balances of self-transfers are overwritten with
# the target-side values.
data = data.withColumn('src_increment', sf.when(sf.col('source') == sf.col('target'), 0).otherwise(sf.col('src_increment'))) \
            .withColumn('dst_increment', sf.when(sf.col('source') == sf.col('target'), 0).otherwise(sf.col('dst_increment'))) \
            .withColumn('old_src_balance', sf.col('new_src_balance') - sf.col('src_increment')) \
            .withColumn('old_dst_balance', sf.col('new_dst_balance') - sf.col('dst_increment')) \
            .withColumn('new_src_balance', sf.when(sf.col('source') == sf.col('target'), sf.col('new_dst_balance')).otherwise(sf.col('new_src_balance'))) \
            .withColumn('old_src_balance', sf.when(sf.col('source') == sf.col('target'), sf.col('old_dst_balance')).otherwise(sf.col('old_src_balance')))

In [None]:
# Whole-dataset statistics per sending account and per receiving account.
sub_df = data.select('source', 'target', 'amount')

src_aggregations = [
    sf.sum('amount').alias('total_sent'),
    sf.avg('amount').alias('avg_sent'),
    sf.stddev('amount').alias('stddev_sent'),
    sf.countDistinct('target').alias('src_total_counterparties'),
]
src_group = sub_df.groupBy('source').agg(*src_aggregations)

dst_aggregations = [
    sf.sum('amount').alias('total_received'),
    sf.avg('amount').alias('avg_received'),
    sf.stddev('amount').alias('stddev_received'),
    sf.countDistinct('source').alias('dst_total_counterparties'),
]
dst_group = sub_df.groupBy('target').agg(*dst_aggregations)

data = data.join(src_group, on='source', how='left')
data = data.join(dst_group, on='target', how='left')

# Each transaction's amount as a percentage of the account-level totals and averages.
for feature_name, reference_col in (
    ('percentage_of_total_sent', 'total_sent'),
    ('percentage_of_total_received', 'total_received'),
    ('percentage_of_avg_sent', 'avg_sent'),
    ('percentage_of_avg_received', 'avg_received'),
):
    data = data.withColumn(feature_name, (sf.col('amount') / sf.col(reference_col)) * 100)

In [None]:
# Persist the features computed so far, replacing any earlier output.
data.write.parquet(WRITE_LOCATION, mode="overwrite")

In [None]:
# Reload the persisted features so later cells start from the stored parquet data.
data = spark.read.format("parquet").load(WRITE_LOCATION)

In [None]:
# Parse the timestamp and derive the calendar keys used by the daily and
# weekly aggregations.
#
# `day` and `week` are running indices, not day-of-month / week-of-year:
#   * day  = number of days since 1970-01-01
#   * week = Monday-based week index (1970-01-01 was a Thursday, hence the +3)
# Day-of-month repeats every month and week-of-year repeats every year, so
# grouping on them would merge unrelated dates once the data spans more than
# one month or year. The week index keeps the Monday-to-Sunday boundaries that
# `weekofyear` uses. Both columns only serve as grouping/join keys.
data = data.withColumn('timestamp', sf.to_timestamp(sf.col('timestamp'), 'yyyy-MM-dd HH:mm'))\
            .withColumn('day', sf.datediff(sf.to_date(sf.col('timestamp')), sf.to_date(sf.lit('1970-01-01'))))\
            .withColumn('week', sf.floor((sf.col('day') + 3) / 7).cast('int'))\
            .withColumn('date', sf.to_date(sf.col('timestamp')))

In [None]:
# Per-day statistics for each sending account ...
src_daily_aggregations = [
    sf.sum('amount').alias('src_daily_paid_amount'),
    sf.count('*').alias('src_daily_trans_count'),
    sf.avg('amount').alias('src_daily_avg_paid_amount'),
    sf.stddev('amount').alias('src_daily_std_amount'),
    sf.countDistinct('target').alias('src_daily_counterparties'),
]
src_daily = data.groupBy('source', 'day').agg(*src_daily_aggregations)

# ... and for each receiving account.
dst_daily_aggregations = [
    sf.sum('amount').alias('dst_daily_received_amount'),
    sf.count('*').alias('dst_daily_trans_count'),
    sf.avg('amount').alias('dst_daily_avg_received_amount'),
    sf.stddev('amount').alias('dst_daily_std_amount'),
    sf.countDistinct('source').alias('dst_daily_counterparties'),
]
dst_daily = data.groupBy('target', 'day').agg(*dst_daily_aggregations)

data = data.join(src_daily, on=['source', 'day'], how='left')
data = data.join(dst_daily, on=['target', 'day'], how='left')

In [None]:
# Length, in daily rows, of the rolling windows used for the "last/next N days" sums.
ROLLING_D = 5
# Number of days by which each account's date range is padded before the rolling sums.
EXPANSION_DAYS = 10

In [None]:
# Narrow projection used as the input of the rolling-window features.
expand_columns = ["source", "target", "amount", "timestamp", "transaction_id", "date"]
df_to_expand = data.select(*expand_columns)

In [None]:
def expand_dates(data, mode, direction):
    """Build one row per (`mode` value, calendar date) over a padded date range.

    For every distinct value of the `mode` column, the span between its
    earliest and latest `date` is expanded into consecutive days and padded by
    `EXPANSION_DAYS` on one side.

    Parameters
    ----------
    data : DataFrame
        Must contain the column named by `mode` and a `date` column.
    mode : str
        Name of the grouping column (the callers pass 'source' or 'target').
    direction : str
        "back" starts the sequence `EXPANSION_DAYS` before the earliest date;
        "fwd" ends it `EXPANSION_DAYS` after the latest date.

    Returns
    -------
    DataFrame
        Two columns: `mode` and `date`.

    Raises
    ------
    ValueError
        If `direction` is neither "back" nor "fwd".
    """
    # Fail loudly on a typo instead of silently returning the unexpanded
    # min/max frame.
    if direction not in ("back", "fwd"):
        raise ValueError(f"direction must be 'back' or 'fwd', got {direction!r}")

    date_range_df = data.groupBy(mode).agg(
        sf.min('date').alias('min_date'),
        sf.max('date').alias('max_date')
    )

    if direction == "back":
        first_day = sf.date_add(sf.col('min_date'), -EXPANSION_DAYS)
        last_day = sf.col('max_date')
    else:
        first_day = sf.col('min_date')
        last_day = sf.date_add(sf.col('max_date'), EXPANSION_DAYS)

    # One output row per day in [first_day, last_day].
    return date_range_df.withColumn(
        'date_sequence',
        sf.explode(sf.sequence(first_day, last_day, sf.expr('INTERVAL 1 DAY')))
    ).select(
        sf.col(mode),
        sf.col('date_sequence').alias('date')
    )

In [None]:
# Total amount sent / received per account and calendar date.
src_daily_totals = df_to_expand.groupBy('source', 'date').agg(sf.sum('amount').alias('out_amount'))
dst_daily_totals = df_to_expand.groupBy('target', 'date').agg(sf.sum('amount').alias('in_amount'))

# Right-join the totals onto the padded calendars so every day in the range
# has a row; days without transactions get 0.
src_daily_totals_back = (
    src_daily_totals
    .join(expand_dates(src_daily_totals, "source", "back"), on=['source', 'date'], how='right')
    .fillna(0)
)
dst_daily_totals_back = (
    dst_daily_totals
    .join(expand_dates(dst_daily_totals, "target", "back"), on=['target', 'date'], how='right')
    .fillna(0)
)

src_daily_totals_fwd = (
    src_daily_totals
    .join(expand_dates(src_daily_totals, 'source', "fwd"), on=['source', 'date'], how='right')
    .fillna(0)
)
dst_daily_totals_fwd = (
    dst_daily_totals
    .join(expand_dates(dst_daily_totals, 'target', "fwd"), on=['target', 'date'], how='right')
    .fillna(0)
)

In [None]:
# Column names of the backward-looking rolling sum and its percentage ...
back_rolling = 'sum_last_{}_d'.format(ROLLING_D)
back_percent = 'percentage_of_total_last_{}_d'.format(ROLLING_D)

# ... and of the forward-looking counterparts.
fwd_rolling = 'sum_next_{}_d'.format(ROLLING_D)
fwd_percent = 'percentage_of_total_next_{}_d'.format(ROLLING_D)

In [None]:
def calculate_percent(data, partition_col, amount_col, rolling_col_name, percent_col_name, ROLLING_D, direction='back'):
    """Add a rolling sum of `amount_col` and the row's percentage of that sum.

    Parameters
    ----------
    data : DataFrame
        Must contain `partition_col`, `amount_col` and a `date` column.
    partition_col : str
        Column the window is partitioned by.
    amount_col : str
        Column that is summed.
    rolling_col_name : str
        Name of the added rolling-sum column.
    percent_col_name : str
        Name of the added percentage column
        (`amount_col / rolling sum * 100`).
    ROLLING_D : int
        Window length in rows, including the current row.
    direction : str
        'back' sums the current row and the `ROLLING_D - 1` rows before it;
        'fwd' sums the current row and the `ROLLING_D - 1` rows after it.

    Returns
    -------
    DataFrame
        `data` with both columns added. Rows whose percentage is null are
        removed.

    Raises
    ------
    ValueError
        If `direction` is neither 'back' nor 'fwd'.
    """
    # Validate up front; previously an unknown direction left `window_spec`
    # unbound and failed later with an unrelated UnboundLocalError.
    if direction not in ('back', 'fwd'):
        raise ValueError(f"direction must be 'back' or 'fwd', got {direction!r}")

    ordered = Window.partitionBy(partition_col).orderBy('date')
    if direction == 'back':
        window_spec = ordered.rowsBetween(-(ROLLING_D - 1), 0)
    else:
        window_spec = ordered.rowsBetween(0, ROLLING_D - 1)

    data = data.withColumn(
        rolling_col_name,
        sf.sum(amount_col).over(window_spec)
    )
    data = data.withColumn(
        percent_col_name,
        (sf.col(amount_col) / sf.col(rolling_col_name) * 100)
    )
    # Keep only rows for which a percentage could be computed.
    data = data.filter(sf.col(percent_col_name).isNotNull())

    return data

In [None]:
# Argument sets shared by the four rolling-percentage computations below.
src_side = dict(partition_col='source', amount_col='out_amount', ROLLING_D=ROLLING_D)
dst_side = dict(partition_col='target', amount_col='in_amount', ROLLING_D=ROLLING_D)
back_window = dict(rolling_col_name=back_rolling, percent_col_name=back_percent, direction='back')
fwd_window = dict(rolling_col_name=fwd_rolling, percent_col_name=fwd_percent, direction='fwd')

# Backward-looking windows.
src_daily_totals_back = calculate_percent(src_daily_totals_back, **src_side, **back_window)
dst_daily_totals_back = calculate_percent(dst_daily_totals_back, **dst_side, **back_window)

# Forward-looking windows.
src_daily_totals_fwd = calculate_percent(src_daily_totals_fwd, **src_side, **fwd_window)
dst_daily_totals_fwd = calculate_percent(dst_daily_totals_fwd, **dst_side, **fwd_window)

In [None]:
# Attach the backward and forward percentages to the per-day totals, then
# drop the amount and rolling-sum helper columns.
rolling_sum_cols = (f'sum_last_{ROLLING_D}_d', f'sum_next_{ROLLING_D}_d')

src_daily_totals = src_daily_totals.join(src_daily_totals_back, on=['source', 'date'], how='left')
src_daily_totals = src_daily_totals.join(src_daily_totals_fwd, on=['source', 'date'], how='left')
src_daily_totals = src_daily_totals.drop('out_amount', *rolling_sum_cols)

dst_daily_totals = dst_daily_totals.join(dst_daily_totals_back, on=['target', 'date'], how='left')
dst_daily_totals = dst_daily_totals.join(dst_daily_totals_fwd, on=['target', 'date'], how='left')
dst_daily_totals = dst_daily_totals.drop('in_amount', *rolling_sum_cols)

In [None]:
# Suffix the percentage columns so the sender and receiver variants can
# coexist once both are joined onto the transactions.
percentage_cols = (
    f'percentage_of_total_last_{ROLLING_D}_d',
    f'percentage_of_total_next_{ROLLING_D}_d',
)

mapping_src = {col_name: f'{col_name}_sent' for col_name in percentage_cols}
mapping_dst = {col_name: f'{col_name}_received' for col_name in percentage_cols}

src_daily_totals = rename_columns(src_daily_totals, mapping_src)
dst_daily_totals = rename_columns(dst_daily_totals, mapping_dst)

In [None]:
# Bring the per-account, per-date percentages onto every transaction.
data = data.join(src_daily_totals, on=['source', 'date'], how='left')
data = data.join(dst_daily_totals, on=['target', 'date'], how='left')

del src_daily_totals
del dst_daily_totals

In [None]:
# The calendar date was only needed as a join key for the rolling features.
data = data.drop("date")

In [None]:
# The joined percentages are per (account, date). Multiply each by the
# transaction's share of that day's amount to get a per-transaction value.
for horizon in ('last', 'next'):
    for side, daily_total_col in (
        ('sent', 'src_daily_paid_amount'),
        ('received', 'dst_daily_received_amount'),
    ):
        percent_col = f'percentage_of_total_{horizon}_{ROLLING_D}_d_{side}'
        data = data.withColumn(
            percent_col,
            sf.col('amount') / sf.col(daily_total_col) * sf.col(percent_col),
        )

In [None]:
# `data` was itself loaded from WRITE_LOCATION and Spark evaluates lazily: an
# overwrite clears the target directory before the source files have been
# read. Materialise the rows first with an eager local checkpoint, which also
# cuts the lineage back to those files, so the overwrite no longer depends on
# them.
data = data.localCheckpoint(eager=True)
data.write.mode("overwrite").parquet(WRITE_LOCATION)

In [None]:
# Reload the persisted features so later cells start from the stored parquet data.
data = spark.read.format("parquet").load(WRITE_LOCATION)

In [None]:
# Transaction counts per (source bank, target bank) pair and per bank overall.
from_to_count = (
    data.groupBy('source_bank', 'target_bank')
    .agg(sf.count('*').alias('src_bank_as_src_with_this_dst_bank'))
)
from_total_count = (
    data.groupBy('source_bank')
    .agg(sf.count('*').alias('src_bank_as_src'))
)
to_from_count = (
    data.groupBy('target_bank', 'source_bank')
    .agg(sf.count('*').alias('dst_bank_as_dst_with_this_src_bank'))
)
to_total_count = (
    data.groupBy('target_bank')
    .agg(sf.count('*').alias('dst_bank_as_dst'))
)

# Pair counts next to the matching bank-level totals.
ratio_from_df = from_to_count.join(from_total_count, on='source_bank', how='inner')
ratio_to_df = to_from_count.join(to_total_count, on='target_bank', how='inner')

In [None]:
# Share (in percent) of a source bank's transactions that go to this target bank.
ratio_from_df = (
    ratio_from_df
    .withColumn(
        'from_to_ratio',
        (sf.col('src_bank_as_src_with_this_dst_bank') / sf.col('src_bank_as_src')) * 100,
    )
    .drop('src_bank_as_src_with_this_dst_bank', 'src_bank_as_src')
)

# Share (in percent) of a target bank's transactions that come from this source bank.
ratio_to_df = (
    ratio_to_df
    .withColumn(
        'to_from_ratio',
        (sf.col('dst_bank_as_dst_with_this_src_bank') / sf.col('dst_bank_as_dst')) * 100,
    )
    .drop('dst_bank_as_dst_with_this_src_bank', 'dst_bank_as_dst')
)

In [None]:
# Attach both bank-pair ratios to every transaction.
data = data.join(ratio_from_df, on=['source_bank', 'target_bank'], how='left')
data = data.join(ratio_to_df, on=['target_bank', 'source_bank'], how='left')

In [None]:
# Release the intermediate bank-level frames.
del from_to_count, from_total_count
del to_from_count, to_total_count
del ratio_from_df, ratio_to_df

In [None]:
# Flag transactions whose payment format is 'Cash'.
cash_flag = sf.when(sf.col('format') == 'Cash', 1).otherwise(0)
data = data.withColumn('is_cash', cash_flag)

# TODO: add the following features in other datasets
# is_cash_credit, is_cash_debit

In [None]:
# Parameters of the "round amount" heuristic.
ROUND_AMOUNT_MULTIPLIER = 50   # amounts are compared against multiples of this value
ROUND_AMOUNT_SLACK = 20        # tolerated distance from such a multiple
ROUND_AMOUNT_MIN_VOLUME = 1000 # only amounts strictly above this can be flagged


def is_round_amount(value):
    """Return 1 if `value` is a large amount close to a round multiple, else 0.

    An amount is flagged when it is strictly greater than
    `ROUND_AMOUNT_MIN_VOLUME` and its remainder modulo
    `ROUND_AMOUNT_MULTIPLIER` lies within `ROUND_AMOUNT_SLACK` of either end
    of the interval. `None` yields 0.
    """
    if value is None:
        return 0
    if not value > ROUND_AMOUNT_MIN_VOLUME:
        return 0
    remainder = value % ROUND_AMOUNT_MULTIPLIER
    just_below_multiple = remainder >= ROUND_AMOUNT_MULTIPLIER - ROUND_AMOUNT_SLACK
    just_above_multiple = remainder <= ROUND_AMOUNT_SLACK
    return int(just_below_multiple or just_above_multiple)

# Register the predicate as a Spark UDF with an explicit integer return type.
# Without one, `sf.udf` defaults to StringType and the 0/1 flag would be
# stored as text, unlike the other integer flags.
is_round_amount_udf = sf.udf(is_round_amount, "int")
data = data.withColumn('is_round_amount', is_round_amount_udf(sf.col('amount')))

In [None]:
# Flag transactions whose source and target currencies differ.
currency_differs = sf.col('source_currency') != sf.col('target_currency')
data = data.withColumn('is_international', sf.when(currency_differs, 1).otherwise(0))

In [None]:
# Flags describing how the two sides of a transaction relate to each other.
same_bank_cond = sf.col('source_bank') == sf.col('target_bank')
different_bank_cond = sf.col('source_bank') != sf.col('target_bank')
same_account_cond = sf.col('source') == sf.col('target')
different_account_cond = sf.col('source') != sf.col('target')

data = (
    data
    .withColumn('is_different_bank', sf.when(different_bank_cond, 1).otherwise(0))
    .withColumn('same_account_same_bank', sf.when(same_account_cond & same_bank_cond, 1).otherwise(0))
    .withColumn('same_bank_diff_account', sf.when(different_account_cond & same_bank_cond, 1).otherwise(0))
)

In [None]:
# Per-day transaction count and average amount for each (source, target) pair.
pair_daily_aggregations = [
    sf.count('*').alias('daily_trans_count_counterparty'),
    sf.avg('amount').alias('daily_avg_paid_amount_counterparty'),
]
src_dst_day = data.groupBy('source', 'target', 'day').agg(*pair_daily_aggregations)

data = data.join(src_dst_day, on=['source', 'target', 'day'], how='left')
del src_dst_day

In [None]:
# `day` is not used as a grouping key after this point.
data = data.drop("day")

In [None]:
# Number of transactions per (source, target) pair, overall and per week.
total_interactions = (
    data.groupBy('source', 'target')
    .count()
    .withColumnRenamed('count', 'total_interactions')
)
weekly_interactions = (
    data.groupBy('source', 'target', 'week')
    .count()
    .withColumnRenamed('count', 'weekly_interactions')
)

data = data.join(total_interactions, on=['source', 'target'], how='left')
data = data.join(weekly_interactions, on=['source', 'target', 'week'], how='left')
del total_interactions
del weekly_interactions

In [None]:
# Weekly amount spread and number of distinct counterparties, per sender ...
src_week = (
    data.groupBy('source', 'week')
    .agg(
        sf.stddev('amount').alias('src_weekly_std_amount'),
        sf.countDistinct('target').alias('src_weekly_counterparties'),
    )
)

# ... and per receiver.
target_week = (
    data.groupBy('target', 'week')
    .agg(
        sf.stddev('amount').alias('dst_weekly_std_amount'),
        sf.countDistinct('source').alias('dst_weekly_counterparties'),
    )
)

data = data.join(src_week, on=['source', 'week'], how='left')
data = data.join(target_week, on=['target', 'week'], how='left')

In [None]:
# Percentage that a single counterparty represents among an account's
# distinct counterparties (overall, per day, per week).
counterparty_features = [
    ('src_percentage_of_counterparty', 'src_total_counterparties'),
    ('dst_percentage_of_counterparty', 'dst_total_counterparties'),
    ('src_daily_percentage_of_counterparty', 'src_daily_counterparties'),
    ('dst_daily_percentage_of_counterparty', 'dst_daily_counterparties'),
    ('src_weekly_percentage_of_counterparty', 'src_weekly_counterparties'),
    ('dst_weekly_percentage_of_counterparty', 'dst_weekly_counterparties'),
]
for feature_name, count_col in counterparty_features:
    data = data.withColumn(feature_name, (1 / sf.col(count_col)) * 100)

In [None]:
# Weekly totals sent by each source and received by each target.
weekly_sent = data.groupBy('source', 'week').agg(sf.sum('amount').alias('weekly_total_sent'))
weekly_received = data.groupBy('target', 'week').agg(sf.sum('amount').alias('weekly_total_received'))

data = data.join(weekly_sent, on=['source', 'week'], how='left')
data = data.join(weekly_received, on=['target', 'week'], how='left')

del weekly_sent
del weekly_received

# Each transaction's amount as a percentage of those weekly totals.
data = (
    data
    .withColumn('percentage_weekly_sent', (sf.col('amount') / sf.col('weekly_total_sent')) * 100)
    .withColumn('percentage_weekly_received', (sf.col('amount') / sf.col('weekly_total_received')) * 100)
)

In [None]:
# `week` is not used as a grouping key after this point.
data = data.drop("week")

In [None]:
# Each transaction's amount as a percentage of the account's daily totals.
data = (
    data
    .withColumn('percentage_daily_sent', (sf.col('amount') / sf.col('src_daily_paid_amount')) * 100)
    .withColumn('percentage_daily_received', (sf.col('amount') / sf.col('dst_daily_received_amount')) * 100)
)

In [None]:
# Amount relative to the account balances; a zero balance yields 0 instead of
# a division by zero. The source side uses the balance before the transaction
# (old_src_balance), the target side the balance after it (new_dst_balance).
src_balance_share = sf.col('amount') / sf.col('old_src_balance') * 100
dst_balance_share = sf.col('amount') / sf.col('new_dst_balance') * 100

data = (
    data
    .withColumn(
        'percentage_of_src_balance',
        sf.when(sf.col('old_src_balance') == 0, 0).otherwise(src_balance_share),
    )
    .withColumn(
        'percentage_of_dst_balance',
        sf.when(sf.col('new_dst_balance') == 0, 0).otherwise(dst_balance_share),
    )
)

In [None]:
# Total number of transactions each account takes part in as sender / receiver.
src_interactions = data.groupBy('source').agg(sf.count('*').alias('src_interactions'))
dst_interactions = data.groupBy('target').agg(sf.count('*').alias('dst_interactions'))

data = data.join(src_interactions, on='source', how='left')
data = data.join(dst_interactions, on='target', how='left')
del src_interactions
del dst_interactions

# Share of those transactions that involve this particular (source, target) pair.
data = (
    data
    .withColumn('src_percentage_of_interactions', sf.col('total_interactions') / sf.col('src_interactions') * 100)
    .withColumn('dst_percentage_of_interactions', sf.col('total_interactions') / sf.col('dst_interactions') * 100)
)

In [None]:
# Seconds elapsed since the previous transaction of the same source / target
# account. The first transaction of an account has no predecessor and gets 0.
source_window = Window.partitionBy("source").orderBy("timestamp")
target_window = Window.partitionBy("target").orderBy("timestamp")

current_seconds = sf.unix_timestamp(sf.col("timestamp"))
src_previous_seconds = sf.unix_timestamp(sf.lag("timestamp", 1).over(source_window))
dst_previous_seconds = sf.unix_timestamp(sf.lag("timestamp", 1).over(target_window))

data = (
    data
    .withColumn("src_time_diff", (current_seconds - src_previous_seconds).cast("double"))
    .withColumn("dst_time_diff", (current_seconds - dst_previous_seconds).cast("double"))
)
data = data.fillna({"src_time_diff": 0, "dst_time_diff": 0})

In [None]:
# Replace the remaining nulls in numeric columns with 0.
data = data.na.fill(0)

In [None]:
# `data` was itself loaded from WRITE_LOCATION and Spark evaluates lazily: an
# overwrite clears the target directory before the source files have been
# read. Materialise the rows first with an eager local checkpoint, which also
# cuts the lineage back to those files, so the overwrite no longer depends on
# them.
data = data.localCheckpoint(eager=True)
data.write.mode("overwrite").parquet(WRITE_LOCATION)

In [None]:
# Reload the persisted features so later cells start from the stored parquet data.
data = spark.read.format("parquet").load(WRITE_LOCATION)

In [None]:
# Collapse the transactions into one weighted edge per (source, target) pair.
pair_amount = sf.sum("amount")
data_aggregated_all = data.groupBy("source", "target").agg(
    pair_amount.alias("sent_amount"),
    pair_amount.alias("received_amount"),
)

mapping_source = data_aggregated_all.groupBy("source").agg(sf.sum("sent_amount").alias("total_sent_by_source"))
mapping_target = data_aggregated_all.groupBy("target").agg(sf.sum("received_amount").alias("total_received_by_target"))

data_aggregated_all = (
    data_aggregated_all
    .join(mapping_source, on="source", how="left")
    .join(mapping_target, on="target", how="left")
)

# Edge weight = the pair's share of the source's total sent amount plus its
# share of the target's total received amount.
sent_share = sf.col("sent_amount") / sf.col("total_sent_by_source")
received_share = sf.col("received_amount") / sf.col("total_received_by_target")
data_aggregated_all = data_aggregated_all.withColumn("weight", sent_share + received_share)

# The graph metrics are computed on the driver with networkx.
edges = data_aggregated_all.select("source", "target", "weight").toPandas()

G = nx.from_pandas_edgelist(edges, source="source", target="target", edge_attr="weight", create_using=nx.DiGraph)
pagerank = nx.pagerank(G)
degree_centrality = nx.degree_centrality(G)

pagerank_df = spark.createDataFrame(list(pagerank.items()), ["node", "pagerank"])
degree_centrality_df = spark.createDataFrame(list(degree_centrality.items()), ["node", "degree_centrality"])

# Attach each metric once for the source account and once for the target account.
for scores_df, metric_col, account_col, feature_col in (
    (pagerank_df, "pagerank", "source", "src_pr"),
    (pagerank_df, "pagerank", "target", "dst_pr"),
    (degree_centrality_df, "degree_centrality", "source", "src_deg_centr"),
    (degree_centrality_df, "degree_centrality", "target", "dst_deg_centr"),
):
    account_scores = (
        scores_df
        .withColumnRenamed("node", account_col)
        .withColumnRenamed(metric_col, feature_col)
    )
    data = data.join(account_scores, on=account_col, how="left")

In [None]:
# `data` was itself loaded from WRITE_LOCATION and Spark evaluates lazily: an
# overwrite clears the target directory before the source files have been
# read. Materialise the rows first with an eager local checkpoint, which also
# cuts the lineage back to those files, so the final overwrite no longer
# depends on them.
data = data.localCheckpoint(eager=True)
data.write.mode("overwrite").parquet(WRITE_LOCATION)
spark.stop()