<a href="https://colab.research.google.com/github/abhishek88agnihotri/pyspark_exercise_Colruyt/blob/main/pyspark_exercise_Colruyt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!apt-get install openjdk-11-jdk-headless -qq > /dev/null
!wget -q https://archive.apache.org/dist/spark/spark-3.4.1/spark-3.4.1-bin-hadoop3.tgz
!tar xf spark-3.4.1-bin-hadoop3.tgz
!pip install -q findspark

In [2]:
from os import environ
import findspark

In [3]:
environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
environ["SPARK_HOME"] = "/content/spark-3.4.1-bin-hadoop3"

In [4]:
# Init spark
findspark.init()

In [5]:
from pyspark.sql import SparkSession
# spark.sql.repl.eagerEval.enabled: Property used to format output tables better

spark = (
    SparkSession
    .builder
    .appName("cg-pyspark-assignment")
    .master("local")
    .config("spark.sql.repl.eagerEval.enabled", True)
    .getOrCreate()
  )

spark


In [6]:
!curl https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/clp-places > clp-places.json
!curl https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/okay-places > okay-places.json
!curl https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/spar-places > spar-places.json
!curl https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/dats-places > dats-places.json
!curl https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/cogo-colpnts > cogo-colpnts.json


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  223k    0  223k    0     0   179k      0 --:--:--  0:00:01 --:--:--  179k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  147k    0  147k    0     0   176k      0 --:--:-- --:--:-- --:--:--  175k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  165k    0  165k    0     0   178k      0 --:--:-- --:--:-- --:--:--  178k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 89162    0 89162    0     0   139k      0 --:--:-- --:--:-- --:--:--  139k
  % Total    % Received % Xferd  Average Speed   Tim

In [7]:
from logging import getLogger, Logger

In [8]:
LOGGER = getLogger()

In [9]:
import logging
import requests
import os
import json
from pyspark.sql import SparkSession

# Create a logger object
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('assignment.log')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# Define the API endpoints
api_endpoints = {
    "clp": "https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/clp-places",
    "okay": "https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/okay-places",
    "spar": "https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/spar-places",
    "dats": "https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/dats-places",
    "cogo": "https://ecgplacesmw.colruytgroup.com/ecgplacesmw/v3/nl/places/filter/cogo-colpnts"
}

# Function to download data from API and save as JSON files
def download_data_from_api(api_endpoints, logger):
    for brand, url in api_endpoints.items():
        response = requests.get(url)
        if response.status_code == 200:
            filename = f"{brand}-places.json"
            with open(filename, 'wb') as file:
                file.write(response.content)
            logger.info(f"Data downloaded for {brand} and saved as {filename}")
        else:
            logger.error(f"Failed to download data for {brand}. Status code: {response.status_code}")

# Example function provided
def get_data_by_brand(brand, logger):
    """Fetch input data based on brand.

    Please add a column to the data indicating the input brand
    Please add minimum one sanity check for loading the data
    Please log things you consider relevant

    Args:
        brand: allowed values are (clp, okay, spar, dats, cogo)
        logger: Logger object for logging

    Returns:
        The relevant dataframe
    """
    filename = f"{brand}-places.json"
    try:
        logger.info(f"Processing data for brand: {brand}...")
        # Read JSON file and create DataFrame
        df = spark.read.json(filename)
        # Add brand column
        df = df.withColumn("brand", lit(brand))
        logger.info(f"Data processed for brand: {brand}")
        return df
    except Exception as e:
        logger.error(f"Error processing data for brand: {brand}. Error: {e}")
        raise e

# Download data from API and save as JSON files
download_data_from_api(api_endpoints, logger)

INFO:__main__:Data downloaded for clp and saved as clp-places.json
INFO:__main__:Data downloaded for okay and saved as okay-places.json
INFO:__main__:Data downloaded for spar and saved as spar-places.json
INFO:__main__:Data downloaded for dats and saved as dats-places.json
INFO:__main__:Data downloaded for cogo and saved as cogo-places.json


In [14]:
#from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, StructType
from pyspark.sql.functions import lit, col, when, split
from pyspark.sql.types import *
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
from pyspark.sql import functions as F
import getpass

# Function to adjust schema of DataFrame
def adjust_schema(df, common_columns):
    for col_name in common_columns:
        if col_name not in df.columns:
            df = df.withColumn(col_name, lit(None))
    return df.select(list(common_columns))

# Function to merge schemas of DataFrames
def merge_schemas(schema1, schema2):
    fields = {}
    # Add fields from schema1
    for field in schema1.fields:
        fields[field.name] = field
    # Add fields from schema2, handle conflicts
    for field in schema2.fields:
        if field.name not in fields:
            fields[field.name] = field

    #print(fields)
    return StructType(list(fields.values()))

# Fetch data from API endpoints for all brands and process into DataFrames
dfs = {}
for brand, url in api_endpoints.items():
    df = get_data_by_brand(brand,logger)
    dfs[brand] = df

#print(dfs)

# Merge schemas of DataFrames
combined_schema = None
for df in dfs.values():
    if combined_schema is None:
        combined_schema = df.schema
    else:
        combined_schema = merge_schemas(combined_schema, df.schema)

print('combined schema of all the brands - \n',combined_schema)

# Extract common columns from all DataFrames
common_columns = set(combined_schema.names)
for df in dfs.values():
    common_columns = common_columns.intersection(set(df.columns))

print('extracting common columns of all the brands - \n',common_columns)

# Adjust schemas for each DataFrame
for brand, df in dfs.items():
    dfs[brand] = adjust_schema(df, common_columns)


# Adjust data types of each DataFrame
for brand, df in dfs.items():
    # Drop the 'placeSearchOpeningHours' column if it exists
    if 'placeSearchOpeningHours' in df.columns:
        df = df.drop('placeSearchOpeningHours')
    # Adjust data types of each column
    for field in combined_schema.fields:
        # Skip if the field is 'placeSearchOpeningHours'
        if field.name == 'placeSearchOpeningHours':
            continue
        # Handle specific columns like 'temporaryClosures' separately
        if field.name == 'temporaryClosures':
            # Assuming the original data type of 'temporaryClosures' is an array of strings
            if field.name in df.columns:
                df = df.withColumn(field.name, F.array(F.struct(F.lit(None).alias('from'), F.lit(None).alias('till'))))
        # For other columns, cast to their original data types
        else:
            if isinstance(field.dataType, ArrayType) and isinstance(field.dataType.elementType, StructType):
                original_data_type = field.dataType
                df = df.withColumn(field.name, col(field.name).cast(original_data_type))
    dfs[brand] = df

#print(dfs)

# Perform the union operation
combined_df = None
for df in dfs.values():
    if combined_df is None:
        combined_df = df
    else:
        combined_df = combined_df.union(df)


# Extract "postal_code" from "address"
#combined_df.printSchema()
combined_df = combined_df.withColumn("postal_code", combined_df.address.postalcode)


# Create new column "province" derived from "postal_code"
combined_df = combined_df.withColumn("province", when(combined_df.postal_code.startswith("1"), "Province 1")
                                     .when(combined_df.postal_code.startswith("2"), "Province 2")
                                     .otherwise("Unknown"))


# Transform geoCoordinates into lat and lon column
combined_df = combined_df.withColumn("latitude", combined_df["geoCoordinates"]["latitude"]) \
                         .withColumn("longitude", combined_df["geoCoordinates"]["longitude"])

combined_df.show(5)

# One-hot-encode the handoverServices if the column exists
if "handoverServices" in combined_df.columns:
    indexer = StringIndexer(inputCol="handoverServices", outputCol="handoverServicesIndex")
    encoder = OneHotEncoder(inputCol="handoverServicesIndex", outputCol="handoverServicesVec")
    pipeline = Pipeline(stages=[indexer, encoder])
    combined_df = pipeline.fit(combined_df).transform(combined_df)
    #combined_df.show(5)
else:
    print("Column 'handoverServices' does not exist in the DataFrame.")

# Fetch the current username
user_name = getpass.getuser()

# Anonymize houseNumber and streetName columns for unauthorized users
# Assuming anonymization involves replacing them with dummy values
authorized_users = ["authorized_user1", "authorized_user2"]
if user_name not in authorized_users:
    combined_df = combined_df.withColumn("houseNumber", lit("GDPR_SENSITIVE_VALUE")) \
                             .withColumn("streetName", lit("GDPR_SENSITIVE_VALUE"))
#combined_df.show(5)

# Drop the 'temporaryClosures' column
combined_df = combined_df.drop('temporaryClosures')


# Save the end result as a Parquet file with partitioning
combined_df.write.partitionBy("province").parquet("brand_data.parquet", mode="overwrite")
print("Data is saved in the file")

# Stop SparkSession
# spark.stop()


INFO:__main__:Processing data for brand: clp...
INFO:__main__:Data processed for brand: clp
INFO:__main__:Processing data for brand: okay...
INFO:__main__:Data processed for brand: okay
INFO:__main__:Processing data for brand: spar...
INFO:__main__:Data processed for brand: spar
INFO:__main__:Processing data for brand: dats...
INFO:__main__:Data processed for brand: dats
INFO:__main__:Processing data for brand: cogo...
INFO:__main__:Data processed for brand: cogo


combined schema of all the brands - 
 StructType([StructField('address', StructType([StructField('cityName', StringType(), True), StructField('countryCode', StringType(), True), StructField('countryName', StringType(), True), StructField('houseNumber', StringType(), True), StructField('postalcode', StringType(), True), StructField('streetName', StringType(), True)]), True), StructField('branchId', StringType(), True), StructField('commercialName', StringType(), True), StructField('ensign', StructType([StructField('id', LongType(), True), StructField('name', StringType(), True)]), True), StructField('geoCoordinates', StructType([StructField('latitude', DoubleType(), True), StructField('longitude', DoubleType(), True)]), True), StructField('handoverServices', ArrayType(StringType(), True), True), StructField('isActive', BooleanType(), True), StructField('moreInfoUrl', StringType(), True), StructField('placeId', LongType(), True), StructField('placeSearchOpeningHours', ArrayType(StructTyp