In [None]:
import datetime
import dateutil.relativedelta

from thetaray.api.context import init_context
from thetaray.common import Constants
from pyspark.sql import SQLContext
import json

from pyspark.sql import functions as f, Window, DataFrame
from pyspark.sql.types import MapType, StringType

from thetaray.api.solution import IngestionMode

from thetaray.api.dataset import dataset_functions
from thetaray.common.data_environment import DataEnvironment

from common.libs.tr_levenshtein import get_lev_ind

import logging

# Shared log line layout: "<LEVEL>: <timestamp> @ <message>".
_LOG_FORMAT = '%(levelname)s: %(asctime)s @ %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

# basicConfig() only acts when the root logger has no handlers yet. Calling it
# first guarantees a handler exists (created with the format above and level
# INFO) instead of failing with IndexError on handlers[0] in a bare interpreter.
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)

# If a handler was already installed, basicConfig() above was a no-op, so the
# format is applied to the first root handler explicitly.
logging.getLogger().handlers[0].setFormatter(
    logging.Formatter(fmt=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
)

from common.libs.config.loader import load_config
from common.libs.config.basic_execution_config_loader import BasicExecutionConfig, DevBasicExecutionConfig
from common.libs.context_utils import is_run_triggered_from_airflow

# Execution date handed to init_context for manual (non-Airflow) runs.
execution_date = Constants.BEGINNING_OF_TIME

if is_run_triggered_from_airflow():
    # Airflow run: the context is created first, and its domain, parameters
    # and Spark configuration drive the execution config.
    context = init_context()
    basic_execution_config = BasicExecutionConfig(domain=context.domain,
                                                  stage=context.parameters['stage'],
                                                  cadence=context.parameters["cadence"],
                                                  entity=context.parameters['entity'],
                                                  spark_conf=context.spark_conf)
else:
    # Manual run: the dev config is created first and supplies the domain and
    # Spark configuration used to build the context.
    basic_execution_config = DevBasicExecutionConfig()
    context = init_context(execution_date=execution_date,
                           domain=basic_execution_config.domain,
                           spark_conf=basic_execution_config.spark_conf)

print(basic_execution_config)
spark = context.get_spark_session()
# SQLContext expects a SparkContext as its first argument; the session is
# passed separately so the SQLContext is bound to this exact session.
sc = SQLContext(spark.sparkContext, sparkSession=spark)
params = context.parameters
print(f"Spark UI URL: {context.get_spark_ui_url()}")

print(json.dumps(params, indent=4))

# Presumably the string values to treat as "missing" in later cells; not
# referenced in the visible part of this notebook -- TODO confirm usage.
missing_values = ['', 'none', 'null']

In [None]:
# Load the two reference datasets used by the enrichment steps below.
keywords_df = dataset_functions.read(context, 'keywords')
country_risk_df = dataset_functions.read(context, 'country_risk')

# Report the row count of each reference dataset, in load order.
for _name, _frame in (('keywords_df', keywords_df), ('country_risk_df', country_risk_df)):
    print(f'{_name} count: {_frame.count()}')

In [None]:
# Load the basic transactions; tr_timestamp is dropped right after the read.
joined_basic_trx = dataset_functions.read(context, 'trx_basic')
joined_basic_trx = joined_basic_trx.drop('tr_timestamp')

trx_basic_count = joined_basic_trx.count()
print(f'joined_basic_trx count: {trx_basic_count}')

# Nothing to enrich: stop the run instead of continuing with an empty input.
if not trx_basic_count:
    raise Exception('trx_basic count is 0, aborting run')

In [None]:
# Me-to-me indicator: get_lev_ind is applied to creditor_name and debtor_name
# with the given threshold (presumably a Levenshtein-distance based match,
# going by the helper's module name -- confirm against tr_levenshtein).
threshold = 1
joined_basic_trx = joined_basic_trx.withColumn(
    "is_me_to_me",
    get_lev_ind('creditor_name', 'debtor_name', threshold),
)
print("Me-to-me field added")

In [None]:
#Keywords
from common.libs.feature_engineering_computations_utils import enrich_trx_with_keywords
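# TODO: enrich joined_basic_trx with keyword matches (presumably via enrich_trx_with_keywords and keywords_df)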

In [None]:
# TODO: Join with country_risk dataset 

In [None]:
# TODO: add additional columns

### Schema padding

In [None]:
from thetaray.utils.type_utils import convert_tr_type_to_spark_type
from common.libs.context_utils import get_dataset
# Target dataset definition; its field_list defines the output columns and
# their order.
trx_enriched_ds = get_dataset(context, 'trx_enriched')

# Column names currently present, captured once as a set for O(1) membership.
existing_columns = set(joined_basic_trx.columns)

# Fields declared by trx_enriched that the DataFrame does not provide yet.
null_fields = [field for field in trx_enriched_ds.field_list if field.identifier not in existing_columns]

# Build the final projection in a single select: existing columns are kept,
# missing ones become typed NULL literals. This avoids chaining one
# withColumn() projection per missing field, which grows the query plan.
joined_basic_trx = joined_basic_trx.select([
    f.col(field.identifier)
    if field.identifier in existing_columns
    else f.lit(None).cast(convert_tr_type_to_spark_type(field)).alias(field.identifier)
    for field in trx_enriched_ds.field_list
])

In [None]:
# year_month holds delivery_timestamp truncated to the start of its month.
joined_basic_trx = joined_basic_trx.withColumn(
    'year_month',
    f.date_trunc('month', f.col('delivery_timestamp')),
)

In [None]:
# Show the final schema for inspection, then write the enriched
# transactions to the trx_enriched dataset.
joined_basic_trx.printSchema()
print('writing trx_enriched')
dataset_functions.write(context, joined_basic_trx, 'trx_enriched')

In [None]:
# print('publishing trx_enriched')
# dataset_functions.publish(context, 'trx_enriched')

In [None]:
# Close the platform context at the end of the run (presumably this also
# releases the Spark session obtained from it -- confirm against the API).
context.close()