In [0]:
_ = '''
initial data
filter out customers where cust_id is null
target table, big table
valid cust set, invalid cust set
filter for valid customers
populate target table with valid customers
'''

In [0]:
from pyspark.sql import SparkSession, Row, functions
from pyspark.sql.types import (
    StructField
    , StringType
    , IntegerType
    , DoubleType
    , StructType
)

In [0]:
spark = SparkSession.builder.appName("Mini_Data_Model_Test").getOrCreate()

In [0]:
table_A_data = {
    1: (1769, 6421.79, 0, 'Dave', 30)
    , 2: (8341, 1137.20, 1, 'Bob', 42)
    , 3: (4507, 911.93, 0, 'Jess', 21)
    , 4: (None, 1234.0, 0, 'Lisa', 37)
    , 5: (4405, 7891.37, None, 'Andy', 38)
    , 6: (6666, 1234.0, 0, None, 99)
}

table_B_data = {
    0: (1769, 30, 6421.79, 0, None)
    , 1: (4507, 21, 911.93, 0, 48)
    , 2: (4405, 38, 7891.37, 1, 587)
    , 3: (8341, 42, 10.56, 1, 999)
    , 4: (7777, 56, 4939.32, None, 1233)
    , 5: (6666, 98, 1234.0, 0, 1234)
}

table_C_data = {
    4232: (None, 58.58, 23)
    , 9947: (687, 3400.21, 66)
    , 8888: (951, 78.29, 18)
    , 6666: (543, 435.0, 97)
}

In [0]:
dfA = spark.createDataFrame([
        Row(
            idx=i
            , cust_id=r[0]
            , balance_A=r[1]
            , churn_A=r[2]
            , name_A=r[3]
            , age_A=r[4]
        )
        for i,r in table_A_data.items() if r[0] != None
])

dfB = spark.createDataFrame([
        Row(
            idx=i
            , cust_id=r[0]
            , age_B=r[1]
            , balance_B=r[2]
            , churn_B=r[3]
            , c_score_B=r[4]
        )
        for i,r in table_B_data.items() if r[0] != None
])

dfC = spark.createDataFrame([
        Row(
            cust_id=i
            , c_score_C=r[0]
            , balance_C=r[1]
            , age_C=r[2]
        )
        for i,r in table_C_data.items() if i != None
])

In [0]:
targ_df = spark.createDataFrame([], schema=StructType(fields=[
    StructField('cust_id', IntegerType(), True)
    , StructField('name', StringType(), True)
    , StructField('age', IntegerType(), True)
    , StructField('balance', DoubleType(), True)
    , StructField('c_score', IntegerType(), True)
    , StructField('churn', IntegerType(), True)
]))

In [0]:
full_joined_df = dfA.join(dfB, on='cust_id', how='full').join(dfC, on='cust_id', how='full')
full_joined_df = full_joined_df.withColumn("nulls", functions.lit(None))
full_joined_df.show() # .select(['cust_id', 'name', 'age', 'balance', 'c_score', 'churn'])

+-------+----+---------+-------+------+-----+----+-----+---------+-------+---------+---------+---------+-----+-----+
|cust_id| idx|balance_A|churn_A|name_A|age_A| idx|age_B|balance_B|churn_B|c_score_B|c_score_C|balance_C|age_C|nulls|
+-------+----+---------+-------+------+-----+----+-----+---------+-------+---------+---------+---------+-----+-----+
|   1769|   1|  6421.79|      0|  Dave|   30|   0|   30|  6421.79|      0|     null|     null|     null| null| null|
|   4232|null|     null|   null|  null| null|null| null|     null|   null|     null|     null|    58.58|   23| null|
|   4405|   5|  7891.37|   null|  Andy|   38|   2|   38|  7891.37|      1|      587|     null|     null| null| null|
|   4507|   3|   911.93|      0|  Jess|   21|   1|   21|   911.93|      0|       48|     null|     null| null| null|
|   6666|   6|   1234.0|      0|  null|   99|   5|   98|   1234.0|      0|     1234|      543|    435.0|   97| null|
|   7777|null|     null|   null|  null| null|   4|   56|  4939.3

In [0]:
full_joined_df.select([c for c in full_joined_df.columns if 'bal' in c] + ['nulls']).show()

+---------+---------+---------+-----+
|balance_A|balance_B|balance_C|nulls|
+---------+---------+---------+-----+
|  6421.79|  6421.79|     null| null|
|     null|     null|    58.58| null|
|  7891.37|  7891.37|     null| null|
|   911.93|   911.93|     null| null|
|   1234.0|   1234.0|    435.0| null|
|     null|  4939.32|     null| null|
|   1137.2|    10.56|     null| null|
|     null|     null|    78.29| null|
|     null|     null|  3400.21| null|
+---------+---------+---------+-----+



In [0]:
# outer loop, looping over cust_id
for r in full_joined_df.select([c for c in full_joined_df.columns if 'bal' in c] + ['nulls']).collect():
    if len(set(v for v in r)) > 2:
        print(r) # set cust_id to not valid

Row(balance_A=1234.0, balance_B=1234.0, balance_C=435.0, nulls=None)
Row(balance_A=1137.2, balance_B=10.56, balance_C=None, nulls=None)


In [0]:
valid_cust_id = set()
invalid_cust_id = set()
all_cust_id = set(r[0] for r in full_joined_df.collect())

In [0]:
for c_id in all_cust_id:
    cust = full_joined_df.filter(full_joined_df.cust_id == c_id)
    valid = True
    #check validity name age balance c_score churn
    new_cust = []
    for column_name in targ_df.columns:
        columns = [c for c in full_joined_df.columns if column_name in c] + ['nulls']
        values = set(v for v in cust.select(columns).first())
        if len(values) > 2:
            valid = False
        elif len(values) == 2:
            #print(values)
            new_cust.append([v for v in values if v != None][0])
        elif len(values) == 1:
            new_cust.append(None)
    if valid:
        valid_cust_id.add(c_id)
        data_tuple = tuple(new_cust)
        #print(data_tuple)
        targ_df = targ_df.union(spark.createDataFrame([data_tuple], schema=targ_df.schema))
    else:
        invalid_cust_id.add(c_id)

In [0]:
for e in [valid_cust_id, invalid_cust_id, all_cust_id]:
    print(e)

{7777, 4232, 1769, 4405, 8888, 4507, 9947}
{6666, 8341}
{7777, 4232, 1769, 6666, 4405, 8341, 8888, 4507, 9947}


In [0]:
# populate dataframe
# use len(set(v for v in r)) == 2 to find nonnull values
# if set has 2 elements then
# value = [v for v in set if v][0]

In [0]:
targ_df.show()

+-------+----+---+-------+-------+-----+
|cust_id|name|age|balance|c_score|churn|
+-------+----+---+-------+-------+-----+
|   7777|null| 56|4939.32|   1233| null|
|   4232|null| 23|  58.58|   null| null|
|   1769|Dave| 30|6421.79|   null|    0|
|   4405|Andy| 38|7891.37|    587|    1|
|   8888|null| 18|  78.29|    951| null|
|   4507|Jess| 21| 911.93|     48|    0|
|   9947|null| 66|3400.21|    687| null|
+-------+----+---+-------+-------+-----+



In [0]:
_ = '''

get_keys = lambda _df, _pk: set(row[_pk] for row in _df.collect())
key = 'cust_id'

for row in dfA.collect():
    cust_id = row[key]
    targdf_keys = get_keys(targ_df, key)
    if cust_id:
        if cust_id in targdf_keys:
            overlap = targ_df.select(dfA.columns)
            control_row = overlap.filter(key == cust_id).first()
            existing_row = targ_df.filter(key == cust_id).first()
            targ_df = targ_df.filter(key != cust_id)
            new_row = {c:existing_row[c] for c in targ_df.columns}
            if all():
                for k,v in new_row.items():
                    if v == None:
                        new_row[k] = row[k]
            else:
                data_tuple = (cust_id, row['name'], row['age'], row['balance'], None, row['churn'], 'A')
                rejected_df = rejected_df.union(spark.createDataFrame([data_tuple], schema=rejected_df.schema))
        else:
            data_tuple = (cust_id, row['name'], row['age'], row['balance'], None, row['churn'])
            targ_df = targ_df.union(spark.createDataFrame([data_tuple], schema=targ_df.schema))
    else:
        data_tuple = (cust_id, row['name'], row['age'], row['balance'], None, row['churn'], 'A')
        rejected_df = rejected_df.union(spark.createDataFrame([data_tuple], schema=rejected_df.schema))

'''