In [42]:
from pyspark.sql import SparkSession, DataFrame
import logging
from pyspark.sql.utils import AnalysisException
from pyspark.sql import functions as F

In [35]:
# Create (or reuse) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("StackOverFlow answers")
    .getOrCreate()
)
    

In [36]:
# Configure the root logger: INFO and above, each record prefixed with
# its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

In [41]:
def ReadDataframe(path=None, format="csv", **options) -> DataFrame:
    """Load the data at ``path`` into a Spark DataFrame.

    Reads through the module-level ``spark`` session.

    Args:
        path: Location of the data to load; must be a non-empty string.
        format: Name of the Spark data source to read with (default "csv").
        **options: Optional reader options forwarded to the DataFrameReader,
            e.g. ``header=True`` so a CSV's first row supplies the column
            names instead of the generated ``_c0, _c1, ...`` names, or
            ``inferSchema=True``. With no options the reader defaults apply.

    Returns:
        The loaded DataFrame.

    Raises:
        ValueError: If ``path`` is not a non-empty string.
        AnalysisException: Logged and re-raised when Spark rejects the load.
        Exception: Any other error is logged and re-raised unchanged.
    """
    if not isinstance(path, str) or not path:
        raise ValueError("** LOAD ** : The 'path' parameter must be a non-empty string.")

    try:
        # options() with no keyword arguments leaves the reader untouched,
        # so existing callers get the same behaviour as before.
        reader = spark.read.format(format).options(**options)
        return reader.load(path)

    except AnalysisException as e:
        logging.error("** LOAD ** : Error loading data into DataFrame. AnalysisException: %s", e)
        raise

    except Exception as e:
        logging.error("** LOAD ** : An unexpected error occurred while loading data into DataFrame: %s", e)
        raise

def show(df, path=None):
    """Log that ``df`` was loaded, then print it with ``df.show()``.

    Args:
        df: DataFrame to display.
        path: Optional source path to include in the log message. It is an
            explicit argument so the message describes the DataFrame being
            shown rather than whatever a global ``path`` happens to hold.
    """
    if path is None:
        logging.info("** LOAD ** : Data was successfully loaded into a DataFrame")
    else:
        logging.info("** LOAD ** : Data from path %s was successfully loaded into a DataFrame", path)
    df.show()

# Usage
# NOTE(review): absolute local path; consider a configurable data directory.
path = "C:/Education/PySpark_Learning/data/data_per_100k_habitants.csv"
df = ReadDataframe(path=path)
show(df, path=path)



+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+----+--------------------+----+---------+---------------+--------------------+
| _c0|                 _c1| _c2|      _c3|            _c4|                 _c5| _c6|                 _c7| _c8|      _c9|           _c10|                _c11|_c12|                _c13|_c14|     _c15|           _c16|                _c17|_c18|                _c19|_c20|     _c21|           _c22|    

In [52]:
data = [
    (1, 12.5, 5.5, 9.5),
    (2, 3.0, 14.0, 6.7)
]

columns = ["id", "trx_holiday", "trx_takeout", "trx_pet"]

# Columns compared when looking for the per-row maximum
value_columns = ["trx_holiday", "trx_takeout", "trx_pet"]

# Create DataFrame
df = spark.createDataFrame(data, columns)

# Highest value across the compared columns, per row
df = df.withColumn("max_value", F.greatest(*value_columns))

# Name of the column holding that highest value. The expression is built
# from value_columns so the two cannot drift apart; coalesce keeps the first
# non-null branch, so on a tie the earliest column in value_columns wins.
cond_max = F.coalesce(*[
    F.when(F.col(c) == F.col("max_value"), F.lit(c)) for c in value_columns
])

df = df.withColumn("MAX", cond_max)

# Collect the winning column names to the driver as a plain Python list
cols = [row["MAX"] for row in df.select("MAX").collect()]
print(cols)

# TODO: the second-highest value/column (max_value_2 / MAX_2) is not computed yet.

df.show()

['trx_holiday', 'trx_takeout']
+---+-----------+-----------+-------+---------+-----------+
| id|trx_holiday|trx_takeout|trx_pet|max_value|        MAX|
+---+-----------+-----------+-------+---------+-----------+
|  1|       12.5|        5.5|    9.5|     12.5|trx_holiday|
|  2|        3.0|       14.0|    6.7|     14.0|trx_takeout|
+---+-----------+-----------+-------+---------+-----------+

