<a href="https://colab.research.google.com/github/aks-vijay/Apache-Spark/blob/main/Ingest%2C_Wrangle_%26_Export_F1_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [139]:
import re
from typing import List

from pyspark.sql import SparkSession
from pyspark.sql.functions import coalesce, col, concat, current_timestamp, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType

class ProccessData:
    """Read a headered CSV file into a Spark DataFrame and offer wrangling helpers.

    Constructing an instance obtains (or creates) a local SparkSession, reads
    ``datapath`` with the supplied schema, and appends an ``ingestion_date``
    column set to the current timestamp. It also builds a small lookup
    DataFrame that maps country names to two-letter country codes.
    """

    def __init__(self, datapath, schema) -> None:
        """Create the Spark session, load the CSV, and build the country lookup.

        Args:
            datapath: Location of the CSV file to read (first row is a header).
            schema: Schema handed to the Spark CSV reader.
        """
        self.spark = (
            SparkSession.builder
            .appName("MyLocalSparkApp")
            .master("local[*]")
            .getOrCreate()
        )
        # Raw input plus an audit column recording when it was ingested.
        self.df = (
            self.spark.read
            .schema(schema=schema)
            .option("header", True)
            .csv(datapath)
            .withColumn("ingestion_date", current_timestamp())
        )

        # Mapping from country name (as spelled in the data) to country code.
        # NOTE(review): the ISO 3166-1 alpha-2 code for the United Kingdom is
        # "GB"; "UK" is kept unchanged here - confirm which one is intended.
        self.country_codes = {
            "Russia": "RU",
            "Sweden": "SE",
            "Malaysia": "MY",
            "Singapore": "SG",
            "Turkey": "TR",
            "Germany": "DE",
            "France": "FR",
            "Argentina": "AR",
            "Belgium": "BE",
            "China": "CN",
            "India": "IN",
            "Italy": "IT",
            "Spain": "ES",
            "Monaco": "MC",
            "Morocco": "MA",
            "USA": "US",
            "Mexico": "MX",
            "Azerbaijan": "AZ",
            "UK": "UK",
            "Saudi Arabia": "SA"
        }
        # The same mapping as (name, code) pairs and as a two-column DataFrame
        # with columns "country" and "country_codes".
        self.country_codes_lists = list(self.country_codes.items())
        self.country_codes_lookup = self.spark.createDataFrame(
            self.country_codes_lists, schema=["country", "country_codes"]
        )

    def extract_all_columns(self, df) -> List[str]:
        """Return the names of all columns of ``df``, in ``df.dtypes`` order."""
        return [column for column, _ in df.dtypes]

    def extract_string_columns(self, df) -> List[str]:
        """Return the names of the columns of ``df`` whose dtype is ``"string"``."""
        return [column for column, datatype in df.dtypes if datatype == "string"]

    def extract_integer_columns(self, df) -> List[str]:
        """Return the names of the columns of ``df`` whose dtype is ``"int"``."""
        return [column for column, datatype in df.dtypes if datatype == "int"]

    def camel_to_snake(self, column: str) -> str:
        """Convert a camelCase name to snake_case.

        An underscore is inserted wherever a lowercase letter or digit is
        directly followed by an uppercase letter, then the result is
        lowercased (``"circuitId"`` becomes ``"circuit_id"``).
        """
        return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', column).lower()

    def extract_columns_and_convert_to_snake(self):
        """Rename every column of ``self.df`` to snake_case.

        ``self.df`` is replaced by the renamed DataFrame, which is also
        returned.
        """
        # The loop variable is deliberately not called ``col`` so that it does
        # not shadow the ``col`` function imported from pyspark.sql.functions.
        for column_name in self.df.columns:
            self.df = self.df.withColumnRenamed(
                column_name, self.camel_to_snake(column_name)
            )
        return self.df

    def write_to_file(self, df, datapath: str) -> None:
        """Write ``df`` as Parquet next to ``datapath``, overwriting old output.

        The output location is ``datapath`` with a trailing ``.csv`` swapped
        for ``.parquet``. Only the suffix is touched, so a ``.csv`` that occurs
        elsewhere in the path (for example in a directory name) is preserved.
        When ``datapath`` does not end in ``.csv``, ``.parquet`` is appended so
        that the overwrite can never target ``datapath`` itself.

        Args:
            df: DataFrame to export.
            datapath: Path the output location is derived from.
        """
        csv_suffix = ".csv"
        if datapath.endswith(csv_suffix):
            export_file_name = datapath[:-len(csv_suffix)] + ".parquet"
        else:
            export_file_name = datapath + ".parquet"
        df.write \
            .mode("overwrite") \
            .parquet(export_file_name)

  and should_run_async(code)



In [148]:
# User-defined schemas: one (column name, data type, nullable) entry per CSV
# column, turned into StructFields below.
_circuit_columns = [
    ("circuitId", IntegerType(), False),
    ("circuitRef", StringType(), True),
    ("name", StringType(), True),
    ("location", StringType(), True),
    ("country", StringType(), True),
    ("lat", DoubleType(), True),
    ("lng", DoubleType(), True),
    ("alt", IntegerType(), True),
    ("url", StringType(), True),
]
circuits_user_defined_schema = StructType(
    fields=[
        StructField(name, data_type, nullable)
        for name, data_type, nullable in _circuit_columns
    ]
)

_race_columns = [
    ("raceId", IntegerType(), False),
    ("year", IntegerType(), True),
    ("round", IntegerType(), True),
    ("circuitId", IntegerType(), False),
    ("name", StringType(), True),
    ("date", DateType(), True),
    ("time", StringType(), True),
    ("url", StringType(), True),
]
races_user_defined_schema = StructType(
    fields=[
        StructField(name, data_type, nullable)
        for name, data_type, nullable in _race_columns
    ]
)


# Locations of the raw CSV inputs.
circuits_input_file_path = "/content/circuits.csv"
races_input_file_path = "/content/races.csv"

# One processor per input file; each loads its CSV with the matching schema.
circuits_data_processor = ProccessData(
    circuits_input_file_path,
    circuits_user_defined_schema,
)
races_data_processor = ProccessData(
    races_input_file_path,
    races_user_defined_schema,
)

# Data cleaning, step 1: rename every column from camelCase to snake_case.
df_renamed_races = races_data_processor.extract_columns_and_convert_to_snake()
df_renamed_circuits = circuits_data_processor.extract_columns_and_convert_to_snake()

# Clean the races data: merge the separate "date" and "time" columns into a
# single "race_date_time" string column ("<date> <time>") and drop the two
# originals.
# Both names are passed to drop() as plain strings: mixing a string with a
# Column object in a multi-argument drop() raises TypeError on PySpark
# releases older than 3.4.
df_races_cleaned = (
    df_renamed_races
    .withColumn("race_date_time", concat(col("date"), lit(" "), col("time")))
    .drop("date", "time")
)

# Columns (and their order) kept in the cleaned circuits DataFrame.
columns_to_select = ["circuit_id", "circuit_ref", "name", "location", "country", "location_and_country", "lat", "lng", "alt", "url"]

country_codes_lookup = circuits_data_processor.country_codes_lookup

# Country code where the lookup knows the country, otherwise the original
# country name. Together with the LEFT join below this keeps every circuit:
# an inner join would silently discard all circuits (and, further down, all
# of their races) whose country is missing from the small lookup table.
resolved_country = coalesce(
    country_codes_lookup["country_codes"],
    df_renamed_circuits["country"],
).alias("country")

df_circuits_cleaned = (
    df_renamed_circuits
    # Built before the join, so it always carries the full country name.
    .withColumn("location_and_country", concat(col("location"), lit(", "), col("country")))
    .join(
        country_codes_lookup,
        df_renamed_circuits["country"] == country_codes_lookup["country"],
        "left",
    )
    # Both join inputs expose a "country" column; selecting the resolved
    # expression in its place avoids any ambiguous reference and removes the
    # need for a multi-Column drop(), which PySpark < 3.4 rejects.
    .select(*[
        resolved_country if column_name == "country" else col(column_name)
        for column_name in columns_to_select
    ])
)

# Combine each race with the details of the circuit it was held on. Columns
# are picked through their parent DataFrame because several names
# (name, url, circuit_id) exist on both sides of the join.
_same_circuit = df_races_cleaned["circuit_id"] == df_circuits_cleaned["circuit_id"]

df_races_circuits_cleaned = (
    df_races_cleaned
    .join(df_circuits_cleaned, on=_same_circuit, how="inner")
    .select(
        df_races_cleaned["race_id"],
        df_races_cleaned["year"].alias("race_year"),
        df_races_cleaned["round"],
        df_races_cleaned["circuit_id"],
        df_races_cleaned["name"].alias("race_name"),
        df_races_cleaned["race_date_time"],
        df_circuits_cleaned["name"].alias("circuit_name"),
        df_circuits_cleaned["location"].alias("circuit_location"),
        df_circuits_cleaned["country"].alias("circuit_country"),
        df_circuits_cleaned["lat"].alias("latitude"),
        df_circuits_cleaned["lng"].alias("longitude"),
        df_races_cleaned["url"].alias("race_url"),
        df_circuits_cleaned["url"].alias("circuit_url"),
        df_races_cleaned["ingestion_date"],
    )
)

# Export the joined result as Parquet.
# NOTE(review): the output location is derived from the *circuits* input path,
# so the combined races/circuits data lands in /content/circuits.parquet -
# confirm that this file name is intended.
circuits_data_processor.write_to_file(
    df_races_circuits_cleaned,
    circuits_input_file_path,
)