In [1]:
import findspark, os, pandas as pd
# Point findspark at a specific local Spark distribution. SPARK_HOME must be set
# before findspark.init() is called, so the order of these two lines matters.
# NOTE(review): this is an absolute, machine-specific path; it has to be edited
# before the notebook can run anywhere else.
os.environ["SPARK_HOME"]="/Users/ankitkansal/spark/spark-2.4.4-bin-without-hadoop"
findspark.init()

In [2]:
import pyspark
import datetime
import re
from functools import reduce
import pandas as pd
from pyspark import SparkContext, sql
from pyspark.sql import functions as F, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from pyspark.sql.functions import array, col, explode, lit, struct
from pyspark.sql import DataFrame
from typing import Iterable

In [3]:
# Notebook-wide Spark session handle; getOrCreate() reuses a session if one is
# already active instead of starting a second one.
spark = SparkSession.builder.getOrCreate()

In [3]:
def melt(df: DataFrame,
         id_vars: Iterable[str], value_vars: Iterable[str],
         var_name: str = "variable", value_name: str = "value") -> DataFrame:
    """Convert :class:`DataFrame` from wide to long format.

    Args:
        df: Wide-format DataFrame to unpivot.
        id_vars: Names of the columns carried through unchanged. Any iterable
            of strings is accepted (list, tuple, generator, ...).
        value_vars: Names of the columns to unpivot. They are packed into a
            single array column, so their values need a common type.
        var_name: Name of the output column holding the source column name.
        value_name: Name of the output column holding the source column value.

    Returns:
        DataFrame containing the ``id_vars`` columns followed by ``var_name``
        and ``value_name``, with one output row per input row and value column.
    """
    # Materialise first: the annotation promises any iterable, but list
    # concatenation below only works on a real list.
    id_cols = list(id_vars)

    # Spark functions are reached through the ``F`` module alias rather than the
    # bare ``col``/``array``/... names, so a notebook cell that rebinds one of
    # those short names (e.g. ``for col in ...``) cannot break this helper.

    # Create array<struct<variable: str, value: ...>>
    vars_and_vals = F.array(*(
        F.struct(F.lit(c).alias(var_name), F.col(c).alias(value_name))
        for c in value_vars))

    # Add to the DataFrame and explode: one row per array element
    exploded = df.withColumn("_vars_and_vals", F.explode(vars_and_vals))

    # Unpack the struct fields back into two top-level columns
    unpacked = [F.col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return exploded.select(*(id_cols + unpacked))

In [4]:
def _build_qc_df_from_func(agg_function, df, label, column_dtype=None):
    """Apply one aggregate to a selection of columns and return it in long format.

    Args:
        agg_function: Callable that takes a column name and returns an
            aggregate ``Column`` expression for that column.
        df: Spark DataFrame to aggregate.
        label: Name of the output column that holds the aggregate value.
        column_dtype: Optional collection of category keywords restricting
            which columns are aggregated. Supported keywords are 'string',
            'numeric', 'date' (date and timestamp), 'bool' and 'qty' (column
            name contains 'qty'). When several keywords are given, a column is
            selected if it matches ANY of them. When falsy, every column is
            aggregated.

    Returns:
        DataFrame with the columns ``field`` (source column name) and
        ``label`` (aggregate value), one row per selected column. If no column
        is selected, an empty DataFrame with both columns typed as string.

    Raises:
        ValueError: If ``column_dtype`` is given but contains none of the
            supported keywords.
    """
    ss = SparkSession.builder.getOrCreate()

    # Classify columns with isinstance() on the DataType objects instead of
    # matching the text of str(dataType), which is a display format rather
    # than a stable identifier.
    numeric_types = (DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType)
    category_tests = [
        ('string', lambda f: isinstance(f.dataType, StringType)),
        ('numeric', lambda f: isinstance(f.dataType, numeric_types)),
        ('date', lambda f: isinstance(f.dataType, (DateType, TimestampType))),
        ('bool', lambda f: isinstance(f.dataType, BooleanType)),
        ('qty', lambda f: 'qty' in f.name),
    ]

    if not column_dtype:
        columns = df.columns
    else:
        # Honour EVERY requested keyword. An if/elif chain would stop at the
        # first match and silently ignore e.g. 'date' in ['numeric', 'date'].
        requested = [test for keyword, test in category_tests if keyword in column_dtype]
        if not requested:
            raise ValueError('unsupported column_dtype argument: {}'.format(column_dtype))
        # Walking the schema keeps schema order and avoids listing a column
        # twice when it matches more than one keyword.
        columns = [f.name for f in df.schema.fields if any(test(f) for test in requested)]

    if not columns:
        empty_schema = StructType(
            [StructField('field', StringType()), StructField(label, StringType())])
        return ss.createDataFrame(ss.sparkContext.emptyRDD(), empty_schema)

    # Aggregate in batches of 10 columns; each df.agg(...) yields a single row,
    # so cross-joining the batches gives one wide row with every aggregate.
    batch_size = 10
    col_batch_list = [columns[i:i + batch_size] for i in range(0, len(columns), batch_size)]
    df_list = [df.agg(*[agg_function(name).alias(name) for name in column_batch])
               for column_batch in col_batch_list]
    # melt() is given 'temp' as a throwaway id column, dropped straight after.
    wide_df = reduce(lambda left, right: left.crossJoin(right), df_list) \
        .withColumn('temp', F.lit("DISCARD"))
    return melt(wide_df, ['temp'], columns).drop('temp') \
        .withColumnRenamed('value', label) \
        .withColumnRenamed('variable', 'field')

In [13]:
def _generate_qc_summary_table(wrk_df: sql.DataFrame, table_name: str) -> sql.DataFrame:
    """Build a per-column quality-control summary of ``wrk_df``.

    Args:
        wrk_df: DataFrame to profile.
        table_name: Literal written to the ``table_name`` column of every row.

    Returns:
        DataFrame with one row per column of ``wrk_df`` and the columns
        ``field``, ``tot_rows``, ``distinct_vals``, ``sum``, ``mean``, ``max``,
        ``min``, ``tot_missing``, ``perc_missing`` and ``table_name``.
    """
    session = SparkSession.builder.getOrCreate()

    # (output label, aggregate builder, column_dtype filter). None means the
    # statistic is computed for every column. 'n' and 'is_not_null_cnt' are
    # computed and joined but do not appear in the final select.
    stat_specs = [
        ('n', lambda c: F.count(F.col(c)), None),
        ('n_distinct', lambda c: F.countDistinct(F.col(c)), None),
        ('is_null_cnt', lambda c: F.sum(F.when(F.col(c).isNull(), 1).otherwise(0)), None),
        ('is_not_null_cnt', lambda c: F.sum((F.col(c).isNotNull().cast('integer'))), None),
        ('sum', lambda c: F.sum(F.col(c)).cast('string'), ['numeric']),
        ('mean_val', lambda c: F.avg(F.col(c)).cast('string'), ['numeric']),
        ('max_val', lambda c: F.max(F.col(c)).cast('string'), ['numeric', 'date']),
        ('min_val', lambda c: F.min(F.col(c)).cast('string'), ['numeric', 'date']),
        ('is_blank_count', lambda c: F.sum((F.col(c) == F.lit('')).cast('integer')), ['string']),
    ]
    stat_frames = [
        _build_qc_df_from_func(builder, df=wrk_df, label=stat_label, column_dtype=dtype_filter)
        for stat_label, builder, dtype_filter in stat_specs
    ]

    row_count = wrk_df.count()
    field_types = [(f.name, str(f.dataType)) for f in wrk_df.schema.fields]
    dtypes_df = session.createDataFrame(field_types, ['field', 'type'])

    # Outer-join the per-statistic frames on 'field', then attach them to the
    # full column list so columns without any statistic still get a row.
    joined_stats = stat_frames[0]
    for frame in stat_frames[1:]:
        joined_stats = joined_stats.join(frame, 'field', 'outer')
    per_field = dtypes_df.join(joined_stats, 'field', 'left')

    # Missing = nulls + blank strings; a statistic absent for a field counts as 0.
    missing_total = (F.coalesce(F.col('is_null_cnt'), F.lit(0))
                     + F.coalesce(F.col('is_blank_count'), F.lit(0)))
    missing_pct = F.round((F.col('overall_missing_values') / F.col('total_rows')) * 100, 2)

    with_missing = per_field.withColumn('overall_missing_values', missing_total)
    with_missing = with_missing.withColumn('total_rows', F.lit(row_count))
    with_missing = with_missing.withColumn('overall_missing_pct', missing_pct)

    output_columns = [
        "field",
        F.col("total_rows").alias("tot_rows"),
        F.col("n_distinct").alias("distinct_vals"),
        "sum",
        F.col("mean_val").alias("mean"),
        F.col("max_val").alias("max"),
        F.col("min_val").alias("min"),
        F.col("overall_missing_values").alias("tot_missing"),
        F.col("overall_missing_pct").alias("perc_missing"),
        F.lit(table_name).alias("table_name"),
    ]
    return with_missing.select(*output_columns)

In [4]:
# Display the session object created above.
spark

In [5]:
# Two-row toy frame; col_2 deliberately carries a missing value in its second row.
data = {
    'col_1': ['ankit', 'kansal'],
    'col_2': ['hellp', None],
}
df = pd.DataFrame(data)

In [6]:
# Display the toy pandas frame.
df

Unnamed: 0,col_1,col_2
0,ankit,hellp
1,kansal,


In [8]:
# Convert the toy pandas frame into a Spark DataFrame.
spark_df = spark.createDataFrame(df)

In [15]:
# Join col_1 and col_2 with " | ". As the second output row shows, concat_ws
# skips the null in col_2 instead of making the whole result null.
spark_df.withColumn("concat", F.concat_ws(" | ", F.col("col_1"), F.col("col_2"))).show()

+------+-----+-------------+
| col_1|col_2|       concat|
+------+-----+-------------+
| ankit|hellp|ankit | hellp|
|kansal| null|       kansal|
+------+-----+-------------+



In [47]:
# Load the test extract; header=True takes the column names from the first line.
# No schema or inferSchema is passed, so the columns presumably load as strings
# — TODO confirm before relying on numeric comparisons downstream.
# NOTE(review): absolute, machine-specific path.
new_df = spark.read.csv("/Users/ankitkansal/Desktop/test.csv", header=True)

In [48]:
# Collect the loaded frame to pandas for display.
new_df.toPandas()

Unnamed: 0,PTID,TREATMENT_START,TREATMENT_END,OBS_BEFORETM_ALCOHOL_CONSUMES_FLAG,OBS_BEFORETM_BMI_HIGH_FLAG,OBS_BEFORETM_BMI_NORMAL_FLAG,OBS_BEFORETM_SMOKE_1_FLAG
0,1,20190722 00:00:00,20191217 00:00:00,1,,1.0,1
1,1,20190717 00:00:00,20200217 00:00:00,1,,,1


In [49]:
# Snapshot of the column names as loaded, taken before any derived columns are
# added; later cells use it to find the flag columns and to restore the layout.
orignal_df_columns = new_df.columns

In [50]:
# Substrings used to group flag columns: a column belongs to a super type when
# its name contains that substring.
super_type_cols = ["ALCOHOL", "BMI", "SMOKE"]

In [51]:
# Add one "<SUPER_TYPE>_DECIDE" column per super type: the sum of every original
# column whose name contains that super type, with NULL values counted as 0.
test_df = new_df
for above_col in super_type_cols:
    new_col_decide = above_col+"_DECIDE"
    # The iteration variable is deliberately not named `col`: at notebook scope
    # that would rebind the `col` function imported from pyspark.sql.functions
    # and break any code that calls `col(...)` after this cell has run.
    local_list = ["coalesce({}, 0)".format(col_name)
                  for col_name in orignal_df_columns if above_col in col_name]
    if not local_list:
        # No column for this super type: an empty sum is not a valid SQL
        # expression, so there is nothing to add.
        continue
    sel_string = " + ".join(local_list)+" as "+new_col_decide
    test_df = test_df.select("*", F.expr(sel_string))

In [68]:
# Display the frame with the derived *_DECIDE columns appended.
test_df.toPandas()

Unnamed: 0,PTID,TREATMENT_START,TREATMENT_END,OBS_BEFORETM_ALCOHOL_CONSUMES_FLAG,OBS_BEFORETM_BMI_HIGH_FLAG,OBS_BEFORETM_BMI_NORMAL_FLAG,OBS_BEFORETM_SMOKE_1_FLAG,ALCOHOL_DECIDE,BMI_DECIDE,SMOKE_DECIDE
0,1,20190722 00:00:00,20191217 00:00:00,1,,1.0,1,1,1.0,1
1,1,20190717 00:00:00,20200217 00:00:00,1,,,1,1,0.0,1


In [63]:
# Recode each flag column against its super type's *_DECIDE total:
#   DECIDE = 1 and flag = 1     -> 1
#   DECIDE = 1 and flag IS NULL -> 0
#   anything else               -> NULL
# NOTE(review): a DECIDE total above 1 (several flags set in one super type)
# falls into the ELSE branch and nulls every flag of that group — confirm that
# this is intended.
final_df = test_df
for above_col in super_type_cols:
    new_col_decide = above_col+"_DECIDE"
    # The iteration variable is deliberately not named `col`: at notebook scope
    # that would rebind the `col` function imported from pyspark.sql.functions
    # and break any code that calls `col(...)` after this cell has run.
    for col_name in orignal_df_columns:
        if above_col in col_name:
            temp_str = "CASE WHEN {0} = 1 AND {1} = 1 THEN 1 WHEN {0} = 1 AND {1} IS NULL THEN 0 ELSE NULL END AS {1}".format(new_col_decide, col_name)
            final_df = final_df.withColumn(col_name, F.expr(temp_str))

In [66]:
# Re-display the input frame as loaded; new_df itself is never reassigned.
new_df.toPandas()

Unnamed: 0,PTID,TREATMENT_START,TREATMENT_END,OBS_BEFORETM_ALCOHOL_CONSUMES_FLAG,OBS_BEFORETM_BMI_HIGH_FLAG,OBS_BEFORETM_BMI_NORMAL_FLAG,OBS_BEFORETM_SMOKE_1_FLAG
0,1,20190722 00:00:00,20191217 00:00:00,1,,1.0,1
1,1,20190717 00:00:00,20200217 00:00:00,1,,,1


In [67]:
# Display the recoded flags restricted to the original column layout (the
# helper *_DECIDE columns are left out of the select).
final_df.select(orignal_df_columns).toPandas()

Unnamed: 0,PTID,TREATMENT_START,TREATMENT_END,OBS_BEFORETM_ALCOHOL_CONSUMES_FLAG,OBS_BEFORETM_BMI_HIGH_FLAG,OBS_BEFORETM_BMI_NORMAL_FLAG,OBS_BEFORETM_SMOKE_1_FLAG
0,1,20190722 00:00:00,20191217 00:00:00,1,0.0,1.0,1
1,1,20190717 00:00:00,20200217 00:00:00,1,,,1
