In [6]:
import pyspark
import os
import logging
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from sqlalchemy import create_engine, text
from dotenv import load_dotenv
from datetime import datetime

In [7]:
# Read connection settings from the local .env file. With override=True,
# python-dotenv lets values from .env replace variables already set in the
# process environment. Any variable that is missing ends up as None.
load_dotenv(".env", override=True)

# Source (operational) database
SOURCE_DB_HOST, SOURCE_DB_USER, SOURCE_DB_PASS, SOURCE_DB_NAME, SOURCE_DB_PORT = (
    os.getenv("SOURCE_DB_HOST"),
    os.getenv("SOURCE_DB_USER"),
    os.getenv("SOURCE_DB_PASS"),
    os.getenv("SOURCE_DB_NAME"),
    os.getenv("SOURCE_DB_PORT"),
)

# Staging database
STG_DB_HOST, STG_DB_USER, STG_DB_PASS, STG_DB_NAME, STG_DB_PORT = (
    os.getenv("STG_DB_HOST"),
    os.getenv("STG_DB_USER"),
    os.getenv("STG_DB_PASS"),
    os.getenv("STG_DB_NAME"),
    os.getenv("STG_DB_PORT"),
)

# Data warehouse database
DWH_DB_HOST, DWH_DB_USER, DWH_DB_PASS, DWH_DB_NAME, DWH_DB_PORT = (
    os.getenv("DWH_DB_HOST"),
    os.getenv("DWH_DB_USER"),
    os.getenv("DWH_DB_PASS"),
    os.getenv("DWH_DB_NAME"),
    os.getenv("DWH_DB_PORT"),
)

# ETL log database
LOG_DB_HOST, LOG_DB_USER, LOG_DB_PASS, LOG_DB_NAME, LOG_DB_PORT = (
    os.getenv("LOG_DB_HOST"),
    os.getenv("LOG_DB_USER"),
    os.getenv("LOG_DB_PASS"),
    os.getenv("LOG_DB_NAME"),
    os.getenv("LOG_DB_PORT"),
)

# Milestones API
API_BASE_URL = os.getenv("API_BASE_URL")

# MinIO object storage (accessed through Spark's s3a filesystem)
MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY = (
    os.getenv("MINIO_ENDPOINT"),
    os.getenv("MINIO_ACCESS_KEY"),
    os.getenv("MINIO_SECRET_KEY"),
)


In [8]:
# Build (or reuse) the Spark session. The s3a settings point Hadoop's S3A
# filesystem at the MinIO endpoint, use path-style bucket addressing and
# disable SSL for the connection.
spark = (
    SparkSession.builder
    .appName("PySpark Project Dev")
    .config("spark.hadoop.fs.s3a.endpoint", MINIO_ENDPOINT)
    .config("spark.hadoop.fs.s3a.access.key", MINIO_ACCESS_KEY)
    .config("spark.hadoop.fs.s3a.secret.key", MINIO_SECRET_KEY)
    .config("spark.hadoop.fs.s3a.path.style.access", "true")
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    .config("spark.hadoop.fs.s3a.connection.ssl.enabled", "false")
    .getOrCreate()
)

spark

## Create database connection helper functions

In [11]:
def source_engine():
    """Return ``(jdbc_url, user, password)`` for the source PostgreSQL database."""
    jdbc_url = "jdbc:postgresql://{}:{}/{}".format(SOURCE_DB_HOST, SOURCE_DB_PORT, SOURCE_DB_NAME)
    return jdbc_url, SOURCE_DB_USER, SOURCE_DB_PASS

def staging_engine():
    """Return ``(jdbc_url, user, password)`` for the staging PostgreSQL database."""
    jdbc_url = "jdbc:postgresql://{}:{}/{}".format(STG_DB_HOST, STG_DB_PORT, STG_DB_NAME)
    return jdbc_url, STG_DB_USER, STG_DB_PASS
    
def staging_engine_sqlalchemy():
    """Create a SQLAlchemy engine for the staging PostgreSQL database.

    The password is percent-encoded before it is embedded in the URL. Without
    this, characters such as '@', ':' or '/' in the password would be read as
    URL delimiters and the connection URL would be mis-parsed.

    Requires STG_DB_PASS to be set; ``quote`` raises TypeError on None.

    Returns:
        The engine returned by ``create_engine`` for the STG_DB_* settings.
    """
    from urllib.parse import quote

    # safe="" also encodes '/', which would otherwise end the authority part.
    password = quote(STG_DB_PASS, safe="")
    return create_engine(f"postgresql://{STG_DB_USER}:{password}@{STG_DB_HOST}:{STG_DB_PORT}/{STG_DB_NAME}")

def dwh_engine_sqlalchemy():
    """Create a SQLAlchemy engine for the data-warehouse PostgreSQL database.

    The password is percent-encoded before it is embedded in the URL. Without
    this, characters such as '@', ':' or '/' in the password would be read as
    URL delimiters and the connection URL would be mis-parsed.

    Requires DWH_DB_PASS to be set; ``quote`` raises TypeError on None.

    Returns:
        The engine returned by ``create_engine`` for the DWH_DB_* settings.
    """
    from urllib.parse import quote

    # safe="" also encodes '/', which would otherwise end the authority part.
    password = quote(DWH_DB_PASS, safe="")
    return create_engine(f"postgresql://{DWH_DB_USER}:{password}@{DWH_DB_HOST}:{DWH_DB_PORT}/{DWH_DB_NAME}")

def dwh_engine():
    """Return ``(jdbc_url, user, password)`` for the data-warehouse PostgreSQL database."""
    jdbc_url = "jdbc:postgresql://{}:{}/{}".format(DWH_DB_HOST, DWH_DB_PORT, DWH_DB_NAME)
    return jdbc_url, DWH_DB_USER, DWH_DB_PASS

def log_engine():
    """Return ``(jdbc_url, user, password)`` for the ETL-log PostgreSQL database."""
    jdbc_url = "jdbc:postgresql://{}:{}/{}".format(LOG_DB_HOST, LOG_DB_PORT, LOG_DB_NAME)
    return jdbc_url, LOG_DB_USER, LOG_DB_PASS

## Set Up Logging Function

In [12]:
def load_log_msg(spark: SparkSession, log_msg: pyspark.sql.DataFrame):
    """Append ETL log rows to the ``etl_log`` table of the log database.

    Args:
        spark: SparkSession (not used in the body).
        log_msg: DataFrame of log rows to append over JDBC.
    """
    log_url, log_user, log_password = log_engine()

    log_msg.write.jdbc(
        url=log_url,
        table="etl_log",
        mode="append",
        properties={
            "user": log_user,
            "password": log_password,
            "driver": "org.postgresql.Driver",
        },
    )

## Extract Database

In [6]:
def extract_database(spark: SparkSession, table_name: str):
    """
    Read one table from the source PostgreSQL database through JDBC.

    Every call appends one row to the ETL log: 'success' when the read works,
    otherwise 'failed' together with the error text.

    Args:
        spark: active SparkSession.
        table_name: name of the source table to read.

    Returns:
        The table as a DataFrame, or None when the read fails (the error is
        printed and logged, not raised).
    """
    db_url, db_user, db_password = source_engine()
    jdbc_props = {
        "user": db_user,
        "password": db_password,
        "driver": "org.postgresql.Driver",
    }
    log_columns = ["step", "process", "status", "source", "table_name", "etl_date"]
    started_at = datetime.now()

    try:
        table_df = spark.read.jdbc(url=db_url, table=table_name, properties=jdbc_props)
        print(f"Extraction process successful for table: {table_name}")

        ok_row = ("sources", "extraction", "success", "source_db", table_name, started_at)
        log_message = spark.sparkContext.parallelize([ok_row]).toDF(log_columns)
        return table_df

    except Exception as exc:
        print(f"Extraction process failed: {exc}")

        failed_row = ("sources", "extraction", "failed", "source_db", table_name, started_at, str(exc))
        log_message = spark.sparkContext.parallelize([failed_row]).toDF(log_columns + ["error_msg"])

    finally:
        load_log_msg(spark=spark, log_msg=log_message)

In [7]:
# Read each source-database table; every call also appends a row to etl_log.
acquisition_df = extract_database(spark, "acquisition")
company_df = extract_database(spark, "company")
funding_rounds_df = extract_database(spark, "funding_rounds")
funds_df = extract_database(spark, "funds")
investments_df = extract_database(spark, "investments")
ipos_df = extract_database(spark, "ipos")

Extraction process successful for table: acquisition
Extraction process successful for table: company
Extraction process successful for table: funding_rounds
Extraction process successful for table: funds
Extraction process successful for table: investments
Extraction process successful for table: ipos


## Extract CSV

In [8]:
def extract_csv(spark: SparkSession, file_name: str):
    """
    Read ``data/<file_name>`` as CSV (first line is the header) and log the outcome.

    No schema inference is requested, so Spark's CSV defaults apply to column
    types. Every call appends one 'success' or 'failed' row to the ETL log.

    Args:
        spark: active SparkSession.
        file_name: file name relative to the ``data/`` directory.

    Returns:
        The file as a DataFrame, or None when reading fails (the error is
        printed and logged, not raised).
    """
    data_dir = "data/"
    started_at = datetime.now()
    log_columns = ["step", "process", "status", "source", "table_name", "etl_date"]

    try:
        csv_df = spark.read.csv(f"{data_dir}{file_name}", header=True)
        print(f"Extraction process successful for file: {file_name}")

        ok_row = ("sources", "extraction", "success", "csv", file_name, started_at)
        log_message = spark.sparkContext.parallelize([ok_row]).toDF(log_columns)
        return csv_df

    except Exception as exc:
        print(f"Extraction process failed: {exc}")

        failed_row = ("sources", "extraction", "failed", "csv", file_name, started_at, str(exc))
        log_message = spark.sparkContext.parallelize([failed_row]).toDF(log_columns + ["error_msg"])

    finally:
        load_log_msg(spark=spark, log_msg=log_message)
        

In [9]:
# Flat-file sources stored under data/
people_df = extract_csv(spark, "people.csv")
relationship_df = extract_csv(spark, "relationships.csv")

Extraction process successful for file: people.csv
Extraction process successful for file: relationships.csv


## Extract API

In [None]:
import requests

def extract_api(spark: SparkSession, start_date: str, end_date: str, timeout: float = 30):
    """
    Fetch milestone records for a date range from the API and log the outcome.

    Args:
        spark: active SparkSession.
        start_date: value sent as the ``start_date`` query parameter.
        end_date: value sent as the ``end_date`` query parameter.
        timeout: seconds passed to ``requests.get`` so that a stalled server
            raises (and is logged as a failure) instead of hanging the run.

    Returns:
        DataFrame built from the JSON payload, or None when the payload is
        empty or when the request/parsing fails. Failures are printed and
        logged to etl_log, not raised. An empty payload is logged as 'success'.
    """
    current_timestamp = datetime.now()

    try:
        # Let requests build and URL-encode the query string rather than
        # interpolating raw values into the URL.
        response = requests.get(
            API_BASE_URL,
            params={"start_date": start_date, "end_date": end_date},
            timeout=timeout,
        )
        response.raise_for_status()

        data = response.json()

        if not data:
            print("There is no data in this range of date")
            df = None
        else:
            print("Extraction process successful for milestones table")
            df = spark.createDataFrame(data)

        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "success", "api", "milestones", current_timestamp)]) \
            .toDF(["step", "process", "status", "source", "table_name", "etl_date"])

    except Exception as e:
        print(f"Extraction process failed: {e}")
        df = None

        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "failed", "api", "milestones", current_timestamp, str(e))]) \
            .toDF(["step", "process", "status", "source", "table_name", "etl_date", "error_msg"])

    finally:
        load_log_msg(spark=spark, log_msg=log_message)

    return df
        

In [11]:
# Milestones between 2014-01-01 and 2015-01-01 (bound semantics are defined by the API).
milestones_df = extract_api(spark, "2014-01-01", "2015-01-01")

Extraction process successful for milestones table


## Data Profiling

In [12]:
# # Check Percentage of Missing Values for each column with pyspark
# import pandas as pd
# import json
# from pyspark.sql.functions import col, count, when, round

# def check_missing_values(df):

#     total_data = df.count()

#     # Calculate the percentage of missing values for each column
#     get_missing_values = df.select([
#         round((count(when(col(column_name).isNull(), column_name)) / total_data) * 100, 2).alias(column_name)
#         for column_name in df.columns
#     ]).collect()[0].asDict()
    
#     return get_missing_values

In [13]:
# data_profiling_report = {
#     "Created by" : "Rico Febrian",
#     "Checking Date" : datetime.now().strftime('%d/%m/%y'),
#     "Column Information": {
#         "Acquisition": {"count": len(acquisition_df.columns), "columns": acquisition_df.columns},
#         "Company": {"count": len(company_df.columns), "columns": company_df.columns},
#         "Funding Rounds": {"count": len(funding_rounds_df.columns), "columns": funding_rounds_df.columns},
#         "Funds": {"count": len(funds_df.columns), "columns": funds_df.columns},
#         "Investments": {"count": len(investments_df.columns), "columns": investments_df.columns},
#         "IPOS": {"count": len(ipos_df.columns), "columns": ipos_df.columns},
#         "People": {"count": len(people_df.columns), "columns": people_df.columns},
#         "Relationships": {"count": len(relationship_df.columns), "columns": relationship_df.columns},
#         "Milestones": {"count": len(milestones_df.columns), "columns": milestones_df.columns}
#     },
#     "Check Data Size": {
#         "Acquisition": acquisition_df.count(),
#         "Company": company_df.count(),
#         "Funding Rounds": funding_rounds_df.count(),
#         "Funds": funds_df.count(),
#         "Investments": investments_df.count(),
#         "IPOS": ipos_df.count(),
#         "People": people_df.count(),
#         "Relationships": relationship_df.count(),
#         "Milestones": milestones_df.count()
#     },
#     "Data Type For Each Column" : {
#         "Acquisition": acquisition_df.dtypes,
#         "Company": company_df.dtypes,
#         "Funding Rounds": funding_rounds_df.dtypes,
#         "Funds": funds_df.dtypes,
#         "Investments": investments_df.dtypes,
#         "IPOS": ipos_df.dtypes,
#         "People": people_df.dtypes,
#         "Relationships": relationship_df.dtypes,
#         "Milestones": milestones_df.dtypes
#     },
#     "Check Missing Value" : {
#         "Acquisition": check_missing_values(acquisition_df),
#         "Company": check_missing_values(company_df),
#         "Funding Rounds": check_missing_values(funding_rounds_df),
#         "Funds": check_missing_values(funds_df),
#         "Investments": check_missing_values(investments_df),
#         "IPOS": check_missing_values(ipos_df),
#         "People": check_missing_values(people_df),
#         "Relationships": check_missing_values(relationship_df),
#         "Milestones": check_missing_values(milestones_df)
#     }
# }

# # Print as neatly formatted JSON
# # print(json.dumps(data_profiling_report, indent=4))

In [14]:
# # Create a function to save the final report to a JSON file
# def save_to_json(dict_result: dict, filename: str) -> None:
#     """
#     This function saves the data profiling result to a JSON file.

#     Args:
#         dict_result (dict): Data profiling result to save to a JSON file.
#         filename (str): Name of the JSON file to save the data profiling result to.

#     Returns:
#         None
#     """

#     try:
        
#         # Save the data profiling result to a JSON file
#         with open(f'{filename}.json', 'w') as file:
#             file.write(json.dumps(dict_result, indent= 4))
    
#     except Exception as e:
#         print(f"Error: {e}")

In [15]:
# save_to_json(dict_result=data_profiling_report, filename="data_profiling_report")

## Dump to Data Lake (MinIO)

In [16]:
# def load_to_minio(df, minio_path):

#     try:
#         df.write \
#           .mode("overwrite") \
#           .parquet(f"s3a://{minio_path}")

#         print(f"Initial load completed to {minio_path}")

#     except Exception as e:
#         print(f"Failed to load data: {e}")

In [17]:
# load_to_minio(df=acquisition_df, minio_path="raw-data/acquisition_table")
# load_to_minio(df=company_df, minio_path="raw-data/company_table")
# load_to_minio(df=funding_rounds_df, minio_path="raw-data/funding_rounds_table")
# load_to_minio(df=funds_df, minio_path="raw-data/funds_table")
# load_to_minio(df=investments_df, minio_path="raw-data/investments_table")
# load_to_minio(df=ipos_df, minio_path="raw-data/ipos_table")
# load_to_minio(df=people_df, minio_path="raw-data/people_table")
# load_to_minio(df=relationship_df, minio_path="raw-data/relationships_table")
# load_to_minio(df=milestones_df, minio_path="raw-data/milestones_table")

## Load to Staging

In [18]:
def load_to_stg(spark, df, table_name, source_name):
    """
    Replace the contents of a staging table with ``df``.

    Step 1 empties the table with ``TRUNCATE ... RESTART IDENTITY CASCADE``
    through SQLAlchemy. Step 2 appends ``df`` over JDBC. If the truncate fails,
    the failure is printed and logged, and the append is still attempted.
    NOTE(review): that can duplicate rows when the table exists but the
    truncate failed; confirm whether the load should be skipped instead.
    The outcome of the append is logged exactly once.

    Args:
        spark: active SparkSession (used to build log rows).
        df: DataFrame to load.
        table_name: target staging table. It is interpolated directly into
            SQL, so it must be a trusted identifier and never user input.
        source_name: value written to the ``source`` column of the log rows.
    """
    current_timestamp = datetime.now()

    # --- Step 1: empty the target staging table --------------------------
    engine = None  # stays None if engine creation fails, so cleanup is skipped
    try:
        engine = staging_engine_sqlalchemy()

        with engine.begin() as connection:
            # Truncate only this staging table
            connection.execute(text(f"TRUNCATE TABLE {table_name} RESTART IDENTITY CASCADE"))

        print(f"Success truncating table: {table_name}")

    except Exception as e:
        print(f"Error when truncating table: {e}")

        log_message = spark.sparkContext\
            .parallelize([("staging", "load", "failed", source_name, table_name, current_timestamp, str(e))])\
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

        load_log_msg(spark=spark, log_msg=log_message)

    finally:
        if engine is not None:
            engine.dispose()

    # --- Step 2: append the extracted DataFrame to the staging db --------
    try:
        STG_DB_URL, STG_DB_USER, STG_DB_PASS = staging_engine()

        # Driver set explicitly, as in every other JDBC call in this notebook.
        properties = {
            "user": STG_DB_USER,
            "password": STG_DB_PASS,
            "driver": "org.postgresql.Driver",
        }

        df.write.jdbc(url=STG_DB_URL,
                      table=table_name,
                      mode="append",
                      properties=properties)

        print(f"Load process successful for table: {table_name}")

        log_message = spark.sparkContext\
            .parallelize([("staging", "load", "success", source_name, table_name, current_timestamp)])\
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

    except Exception as e:
        print(f"Load process failed: {e}")

        log_message = spark.sparkContext\
            .parallelize([("staging", "load", "failed", source_name, table_name, current_timestamp, str(e))])\
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

    # Written once, after the try/except, so a success is not logged twice.
    load_log_msg(spark=spark, log_msg=log_message)

In [19]:
# (DataFrame, staging table, source label) in the order the tables are loaded.
for frame, target_table, source_label in (
    (people_df, "people", "csv"),
    (relationship_df, "relationships", "csv"),
    (company_df, "company", "source_db"),
    (funding_rounds_df, "funding_rounds", "source_db"),
    (funds_df, "funds", "source_db"),
    (acquisition_df, "acquisition", "source_db"),
    (ipos_df, "ipos", "source_db"),
    (investments_df, "investments", "source_db"),
    (milestones_df, "milestones", "api"),
):
    load_to_stg(spark=spark, df=frame, table_name=target_table, source_name=source_label)

Success truncating table: people
Load process successful for table: people
Success truncating table: relationships
Load process successful for table: relationships
Success truncating table: company
Load process successful for table: company
Success truncating table: funding_rounds
Load process successful for table: funding_rounds
Success truncating table: funds
Load process successful for table: funds
Success truncating table: acquisition
Load process successful for table: acquisition
Success truncating table: ipos
Load process successful for table: ipos
Success truncating table: investments
Load process successful for table: investments
Success truncating table: milestones
Load process successful for table: milestones


## Extract data from Staging DB

In [20]:
def extract_staging(spark: SparkSession, table_name: str):
    """
    Read one table from the staging PostgreSQL database through JDBC.

    Every call appends one row to the ETL log: 'success' when the read works,
    otherwise 'failed' together with the error text.

    Args:
        spark: active SparkSession.
        table_name: name of the staging table to read.

    Returns:
        The table as a DataFrame, or None when the read fails (the error is
        printed and logged, not raised).
    """
    stg_url, stg_user, stg_password = staging_engine()
    jdbc_props = {
        "user": stg_user,
        "password": stg_password,
        "driver": "org.postgresql.Driver",
    }
    log_columns = ["step", "process", "status", "source", "table_name", "etl_date"]
    started_at = datetime.now()

    try:
        table_df = spark.read.jdbc(url=stg_url, table=table_name, properties=jdbc_props)
        print(f"Extraction process successful for table: {table_name}")

        ok_row = ("staging", "extraction", "success", "source_db", table_name, started_at)
        log_message = spark.sparkContext.parallelize([ok_row]).toDF(log_columns)
        return table_df

    except Exception as exc:
        print(f"Extraction process failed: {exc}")

        failed_row = ("staging", "extraction", "failed", "source_db", table_name, started_at, str(exc))
        log_message = spark.sparkContext.parallelize([failed_row]).toDF(log_columns + ["error_msg"])

    finally:
        load_log_msg(spark=spark, log_msg=log_message)

In [21]:
# Read back every staging table for the transformation step.
stg_acquisition = extract_staging(spark, "acquisition")
stg_company = extract_staging(spark, "company")
stg_funding_rounds = extract_staging(spark, "funding_rounds")
stg_funds = extract_staging(spark, "funds")
stg_investments = extract_staging(spark, "investments")
stg_ipos = extract_staging(spark, "ipos")
stg_people = extract_staging(spark, "people")
stg_relationships = extract_staging(spark, "relationships")
stg_milestones = extract_staging(spark, "milestones")

Extraction process successful for table: acquisition
Extraction process successful for table: company
Extraction process successful for table: funding_rounds
Extraction process successful for table: funds
Extraction process successful for table: investments
Extraction process successful for table: ipos
Extraction process successful for table: people
Extraction process successful for table: relationships
Extraction process successful for table: milestones


## Data Transformation

In [22]:
def clean_address(col_name: str):
    """
    Cleans address values in a DataFrame column through the following steps:
    1. Converting the entire string to lowercase for standardization.
    2. Removing leading '#' or '.' characters.
    3. Replacing invalid values with NULL. A value is invalid if it consists
       solely of symbols, whitespace, digits and/or underscores, or if its
       length after trimming is 2 characters or less.

    Letters from any script count as letters. An address written only in a
    non-Latin script (e.g. '東京都') is kept, not treated as "symbols only".

    Parameters:
        col_name (str): The name of the column containing the address values to clean.

    Returns:
        Column: A PySpark Column with the cleaned, trimmed address values.
                Invalid values are replaced with NULL.
    """
    # Lowercase, then strip leading '#' / '.' ('#Main St' -> 'main st').
    cleaned = F.regexp_replace(F.lower(F.col(col_name)), r"^[#.]+", "")

    # Condition 1: only non-word characters, digits or underscores remain
    # (e.g. '??', '323', '------').
    # Spark evaluates rlike with Java regex, where \W is ASCII-only by default
    # and would match every non-ASCII letter. The embedded (?U) flag
    # (UNICODE_CHARACTER_CLASS) makes \W and \d Unicode-aware.
    is_only_symbols = cleaned.rlike(r"(?U)^[\W\d_]+$")

    # Condition 2: too short after trimming (e.g. 'a', ' b ', '').
    is_too_short = F.length(F.trim(cleaned)) <= 2

    return F.when(
        is_only_symbols | is_too_short,
        F.lit(None)  # Replace invalid values with NULL
    ).otherwise(
        F.trim(cleaned)  # Keep and trim valid values
    )

In [23]:
def to_usd(currency_col, amount_col, rates=None):
    """
    Build a Column that converts an amount to US dollars, rounded to 2 decimals.

    Parameters:
        currency_col (str): column holding the currency code (e.g. 'EUR').
        amount_col (str): column holding the amount in that currency.
        rates (dict | None): optional mapping of currency code to USD per one
            unit of that currency. Defaults to the fixed rates below. Amounts
            whose currency is 'USD' are always passed through unchanged.

    Returns:
        Column: the converted amount. Rows whose currency is neither 'USD' nor
        a key of ``rates`` (including a NULL currency) keep the original
        amount unconverted.
    """
    if rates is None:
        # Hard-coded rates (USD per 1 unit); they are not fetched live.
        rates = {
            "CAD": 0.72,
            "EUR": 1.14,
            "SEK": 0.10,
            "AUD": 0.64,
            "JPY": 0.007,
            "GBP": 1.33,
            "NIS": 0.28,
        }

    currency = F.col(currency_col)
    amount = F.col(amount_col)

    # Same CASE WHEN chain as an explicit .when() per currency: USD first,
    # then one branch per rate, with the raw amount as the fallback.
    converted = F.when(currency == "USD", amount)
    for code, rate in rates.items():
        converted = converted.when(currency == code, amount * rate)

    return F.round(converted.otherwise(amount), 2)

In [24]:
def transform_company(spark, df):
    """
    Transform raw company data into a cleaned and standardized dimension table.

    Steps:
        1. Derive 'entity_type' ('company' / 'fund' / NULL) from the
           'object_id' prefix ('c:' / 'f:').
        2. Clean 'address1' / 'address2' with clean_address and combine them
           into 'full_address'.
        3. Trim 'region', 'city' (lowercased) and 'country_code' (uppercased),
           turning empty results into NULL.

    Every call appends one row ('success' or 'failed') to the ETL log.

    Args:
        spark: SparkSession
        df: Raw input DataFrame containing company data from staging db

    Returns:
        DataFrame: Cleaned and transformed company dimension table, or None if
        the transformation raised (the error is printed and logged, not raised).
    """

    # Set up current timestamp for logging
    current_timestamp = datetime.now()

    try:
        # Identify whether the record is a 'company' or a 'fund' from the 'object_id' prefix.
        df = df.withColumn(
            "entity_type",
            F.when(F.col("object_id").startswith("c:"), "company")
             .when(F.col("object_id").startswith("f:"), "fund")
             .otherwise(None)
        )

        # Cleaned versions of both address columns.
        df = df.withColumn("address1_cleaned", clean_address(col_name="address1")) \
               .withColumn("address2_cleaned", clean_address(col_name="address2"))

        # 'full_address' = "address1, address2", falling back to whichever part
        # is present, or NULL when both are missing.
        df = df.withColumn(
            "full_address",
            F.when(
                (F.col("address1_cleaned").isNull()) & (F.col("address2_cleaned").isNull()),
                F.lit(None)
            ).when(
                (F.col("address1_cleaned").isNull()) | (F.col("address1_cleaned") == ""),
                F.col("address2_cleaned")
            ).when(
                (F.col("address2_cleaned").isNull()) | (F.col("address2_cleaned") == ""),
                F.col("address1_cleaned")
            ).otherwise(
                F.concat_ws(", ", F.col("address1_cleaned"), F.col("address2_cleaned"))
            )
        )

        # Trim, and lowercase region/city, uppercase country_code.
        region_cleaned = F.trim(F.lower(F.col("region")))
        city_cleaned = F.trim(F.lower(F.col("city")))
        country_code_cleaned = F.trim(F.upper(F.col("country_code")))

        # Replace the originals with the cleaned values; empty strings become NULL.
        df = df.withColumn(
            "region",
            F.when(
                (region_cleaned.isNull()) | (region_cleaned == ""), F.lit(None)
            ).otherwise(region_cleaned)
        ).withColumn(
            "city",
            F.when(
                (city_cleaned.isNull()) | (city_cleaned == ""), F.lit(None)
            ).otherwise(city_cleaned)
        ).withColumn(
            "country_code",
            F.when(
                (country_code_cleaned.isNull()) | (country_code_cleaned == ""), F.lit(None)
            ).otherwise(country_code_cleaned)
        )

        # Dimension columns; 'object_id' becomes the natural key 'nk_company_id'.
        dim_company = df.select(
            F.col("object_id").alias("nk_company_id"),
            F.col("entity_type"),
            F.col("full_address"),
            F.col("region"),
            F.col("city"),
            F.col("country_code")
        )

        print("Transformation process successful for table: dim_company")

        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "success", "staging", "dim_company", current_timestamp)]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

        return dim_company

    except Exception as e:
        print(f"Transformation process failed: {e}")

        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "failed", "staging", "dim_company", current_timestamp, str(e))]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

    finally:
        # Persist the log row for both outcomes.
        load_log_msg(spark=spark, log_msg=log_msg)
        

In [25]:
# Build the company dimension from the staging snapshot.
dim_company = transform_company(spark=spark, df=stg_company)

Transformation process successful for table: dim_company


In [26]:
def clean_name(col_name: str):
    """
    Clean a person-name column, returning NULL for values that look invalid.

    Steps:
        1. Lowercase the value and strip leading characters that are not a-z.
        2. Remove every remaining character other than a-z, whitespace and '-'.
        3. Return NULL when the cleaned value is 2 characters or shorter after
           trimming, or when the raw value, ignoring leading whitespace, starts
           with something other than an ASCII letter (digit, '_', '#', '.', ...).
           Names starting with a non-ASCII letter (e.g. 'É') are therefore NULL.
           Otherwise return the trimmed cleaned value.

    Parameters:
        col_name (str): The name of the column containing the name values to clean.

    Returns:
        Column: A PySpark Column containing the cleaned name values.
                Invalid values will be replaced with NULL.
    """
    # Strip leading non-alphabetic characters such as '#', '.' or digits.
    cleaned = F.regexp_replace(F.lower(F.col(col_name)), r"^[^a-z]+", "")

    # Keep only a-z, whitespace and hyphen.
    cleaned = F.regexp_replace(cleaned, r"[^a-z\s\-]", "")

    # Condition 1: only non-letters remain. Kept as a safeguard; after the two
    # replacements above, a non-empty value always starts with a-z.
    is_only_symbols = cleaned.rlike(r"^[^a-z]+$")

    # Condition 2: too short after trimming.
    is_too_short = F.length(F.trim(cleaned)) <= 2

    # Condition 3: the raw value starts, after optional leading whitespace,
    # with anything other than an ASCII letter. The class must exclude
    # whitespace: with [\d\W_] the \s* could match nothing and the leading
    # space itself would match \W, so ' John' would be rejected.
    starts_invalid = F.col(col_name).rlike(r"^\s*[^\sa-zA-Z]")

    return F.when(
        is_only_symbols | is_too_short | starts_invalid,
        F.lit(None)
    ).otherwise(
        F.trim(cleaned)
    )

In [28]:
from datetime import datetime
from pyspark.sql import functions as F

def transform_people(spark, df):
    """
    Transform raw people data into a cleaned and standardized dimension table.

    Steps:
        1. Clean 'first_name' / 'last_name' with clean_name and combine them
           into 'full_name', falling back to whichever part is present.
        2. Trim and lowercase 'affiliation_name'.

    Every call appends one row ('success' or 'failed') to the ETL log.

    Args:
        spark: SparkSession
        df: Raw input DataFrame containing people data from staging db

    Returns:
        DataFrame: Cleaned and transformed people dimension table, or None if
        the transformation raised (the error is printed and logged, not raised).
    """

    # Set up current timestamp for logging
    current_timestamp = datetime.now()

    try:
        # Cleaned versions of first and last name.
        df = df.withColumn("firstname_cleaned", clean_name(col_name="first_name")) \
               .withColumn("lastname_cleaned", clean_name(col_name="last_name"))

        # 'full_name' = "first last", or whichever part is present, or NULL.
        df = df.withColumn(
            "full_name",
            F.when(
                (F.col("firstname_cleaned").isNull()) & (F.col("lastname_cleaned").isNull()),
                F.lit(None)
            ).when(
                (F.col("firstname_cleaned").isNull()) | (F.col("firstname_cleaned") == ""),
                F.col("lastname_cleaned")
            ).when(
                (F.col("lastname_cleaned").isNull()) | (F.col("lastname_cleaned") == ""),
                F.col("firstname_cleaned")
            ).otherwise(
                F.concat_ws(" ", F.col("firstname_cleaned"), F.col("lastname_cleaned"))
            )
        )

        # Standardize affiliation name: trimmed and lowercase.
        df = df.withColumn(
            "affiliation_name",
            F.trim(F.lower(F.col("affiliation_name")))
        )

        # Dimension columns; 'object_id' becomes the natural key 'nk_people_id'.
        dim_people = df.select(
            F.col("object_id").alias("nk_people_id"),
            F.col("full_name"),
            F.col("affiliation_name")
        )

        print("Transformation process successful for table: dim_people")

        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "success", "staging", "dim_people", current_timestamp)]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

        return dim_people

    except Exception as e:
        print(f"Transformation process failed: {e}")

        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "failed", "staging", "dim_people", current_timestamp, str(e))]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

    finally:
        # Persist the log row for both outcomes.
        load_log_msg(spark=spark, log_msg=log_msg)


In [29]:
# Build the people dimension from the staging snapshot.
dim_people = transform_people(spark=spark, df=stg_people)

Transformation process successful for table: dim_people


In [30]:
def extract_warehouse(spark: SparkSession, table_name):
    """
    Read ``table_name`` from the data-warehouse PostgreSQL database over JDBC.

    Unlike the source/staging extractors, this writes nothing to the ETL log.

    Returns:
        The table as a DataFrame, or None if the read fails (the exception is
        printed, not raised).
    """
    dwh_url, dwh_user, dwh_password = dwh_engine()
    jdbc_props = {
        "user": dwh_user,
        "password": dwh_password,
        "driver": "org.postgresql.Driver",  # PostgreSQL JDBC driver
    }

    try:
        return spark.read.jdbc(url=dwh_url, table=table_name, properties=jdbc_props)
    except Exception as err:
        print(err)

In [35]:
def transform_funds(spark, df):
    """
    Transform raw funds data into a cleaned and standardized dimension table.

    The outcome is recorded via ``load_log_msg``: a "success" row when the
    transformation completes, or a "failed" row carrying the error text.

    Args:
        spark: SparkSession
        df: Raw input DataFrame containing fund data from staging db

    Returns:
        DataFrame: Cleaned and transformed funds dimension table with columns
        nk_fund_id, fund_name, raised_amount_usd, funded_at, fund_description;
        or None if the transformation failed (the error is printed and logged).
    """

    # Set up current timestamp for logging
    current_timestamp = datetime.now()

    try:
        # Extract the dim_date dimension table.
        dim_date = extract_warehouse(spark, table_name="dim_date")

        # Standardize 'name' and 'source_description' by trimming whitespace and converting to lowercase.
        df = df.withColumn("name", F.trim(F.lower(F.col("name")))) \
               .withColumn("source_description", F.trim(F.lower(F.col("source_description"))))

        # Convert raised amount to USD using the to_usd function.
        df = df.withColumn("raised_amount_usd", to_usd(currency_col="raised_currency_code", amount_col="raised_amount"))

        # Add a foreign key 'funded_date_id' by formatting 'funded_at' as a yyyyMMdd integer,
        # the same format as 'date_id' in dim_date.
        df = df.withColumn(
            "funded_date_id",
            F.date_format(df.funded_at, "yyyyMMdd").cast("integer")
        )

        # Left join with dim_date's key only. No dim_date attributes are selected below, so
        # carrying only 'date_id' avoids widening the frame with (and possibly colliding on)
        # every dim_date column.
        # NOTE(review): a left join neither filters rows nor nulls funded_date_id, so ids
        # absent from dim_date are passed through unchanged.
        dim_date_keys = dim_date.select("date_id")
        df = df.join(
            dim_date_keys,
            df.funded_date_id == dim_date_keys.date_id,
            "left"
        )

        # Turn empty 'source_description' values into NULL (exposed as 'fund_description' below).
        df = df.withColumn(
            "source_description",
            F.when(F.trim(df.source_description) == "", None)
              .otherwise(df.source_description)
        )

        # Select the columns for the dimension table and rename for clarity.
        dim_fund = df.select(
            F.col("object_id").alias("nk_fund_id"),
            F.col("name").alias("fund_name"),
            F.col("raised_amount_usd"),
            F.col("funded_date_id").alias("funded_at"),
            F.col("source_description").alias("fund_description")
        )

        print("Transformation process successful for table: dim_fund")

        # Log a success message with details about the ETL process.
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "success", "staging", "dim_fund", current_timestamp)]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

        # Persist the log row; previously it was built and discarded.
        load_log_msg(spark=spark, log_msg=log_msg)

        return dim_fund

    except Exception as e:

        # Log an error message with details about the failure.
        print(e)
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "failed", "staging", "dim_fund", current_timestamp, str(e))]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

        load_log_msg(spark=spark, log_msg=log_msg)

        return None


In [36]:
# Build the funds dimension from the staging extract.
dim_funds = transform_funds(spark=spark, df=stg_funds)

Transformation process successful for table: dim_fund


In [37]:
def transform_relationship(spark, df):
    """
    Transform raw relationships data into a cleaned and standardized bridge table.

    Rows are kept only when both the company (relationship_object_id) and the
    person (person_object_id) resolve to warehouse surrogate keys. Start and end
    dates are looked up in dim_date with LEFT joins. A missing or unknown date
    yields a NULL date id instead of dropping the relationship.

    The outcome is recorded via ``load_log_msg``.

    Args:
        spark: SparkSession
        df: Raw input DataFrame containing relationships data from staging db

    Returns:
        DataFrame: Cleaned and transformed bridge people company table, or None
        if the transformation failed (the error is printed and logged).
    """

    # Set up current timestamp for logging
    current_timestamp = datetime.now()

    try:
        # Extract dimension tables needed for joins.
        dim_date = extract_warehouse(spark, table_name="dim_date")
        dim_company = extract_warehouse(spark, table_name="dim_company")
        dim_people = extract_warehouse(spark, table_name="dim_people")

        # Join with dim_company to get the company's surrogate key (sk_company_id).
        df = df.join(
            dim_company.select("sk_company_id", "nk_company_id"),
            df.relationship_object_id == dim_company.nk_company_id,
            "inner"
        )

        # Join with dim_people to get the person's surrogate key (sk_people_id).
        df = df.join(
            dim_people.select("sk_people_id", "nk_people_id"),
            df.person_object_id == dim_people.nk_people_id,
            "inner"
        )

        # Convert start and end dates to integer format (yyyyMMdd) for consistency.
        df = df.withColumn("relationship_start_at", F.date_format("start_at", "yyyyMMdd").cast("integer")) \
               .withColumn("relationship_end_at", F.date_format("end_at", "yyyyMMdd").cast("integer"))

        # Two single-column views of dim_date, with distinct names, for the start and end lookups.
        dim_date_start = dim_date.select(F.col("date_id").alias("start_date_id"))
        dim_date_end = dim_date.select(F.col("date_id").alias("end_date_id"))

        # LEFT joins: the previous INNER joins silently dropped every relationship whose
        # start_at or end_at was NULL (e.g. one that has not ended) or outside dim_date.
        df = df.join(
            dim_date_start,
            F.col("relationship_start_at") == F.col("start_date_id"),
            "left"
        ).join(
            dim_date_end,
            F.col("relationship_end_at") == F.col("end_date_id"),
            "left"
        )

        # Keep only date ids that exist in dim_date (NULL otherwise), preserving the original
        # guarantee that every emitted date id is present in dim_date.
        df = df.withColumn("relationship_start_at", F.col("start_date_id")) \
               .withColumn("relationship_end_at", F.col("end_date_id"))

        # Clean 'title': lowercase and trim, then treat '.' and empty strings as NULL.
        # The placeholder check runs on the cleaned value, so ' . ' is caught as well.
        cleaned_title = F.trim(F.lower(F.col("title")))
        df = df.withColumn(
            "title",
            F.when(cleaned_title.isin(".", ""), F.lit(None))
             .otherwise(cleaned_title)
        )

        # Select the columns for the bridge table.
        bridge_company_people = df.select(
            F.col("sk_company_id"),
            F.col("sk_people_id"),
            F.col("title"),
            F.col("is_past"),
            F.col("relationship_start_at"),
            F.col("relationship_end_at")
        )

        print("Transformation process successful for table: bridge_company_people")

        # Log a success message for the ETL process.
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "success", "staging", "bridge_company_people", current_timestamp)]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

        # Persist the log row; previously it was built and discarded.
        load_log_msg(spark=spark, log_msg=log_msg)

        return bridge_company_people

    except Exception as e:

        # Log an error message with details about the failure.
        print(e)
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "failed", "staging", "bridge_company_people", current_timestamp, str(e))]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

        load_log_msg(spark=spark, log_msg=log_msg)

        return None


In [38]:
# Build the company/people bridge table from the staging relationships.
bridge_company_people = transform_relationship(spark=spark, df=stg_relationships)

Transformation process successful for table: bridge_company_people


In [41]:
def transform_investments(spark, df):
    """
    Transform raw investments and funding rounds data into a cleaned and standardized fact table.

    Investments are kept only when the funded company and the investor resolve to
    surrogate keys in dim_company / dim_funds. Funding-round attributes are
    attached with a LEFT join. A round's date is looked up in dim_date; an
    unknown or missing date yields a NULL 'funded_at' without discarding the
    round's other attributes.

    The outcome is recorded via ``load_log_msg``.

    Args:
        spark: SparkSession
        df: Raw input DataFrame containing investments data from staging db

    Returns:
        DataFrame: Cleaned and transformed investments fact table, or None if the
        transformation failed (the error is printed and logged).
    """

    # Set up current timestamp for logging
    current_timestamp = datetime.now()

    try:
        # Extract dimension tables needed for joins.
        dim_date = extract_warehouse(spark, table_name="dim_date")
        dim_company = extract_warehouse(spark, table_name="dim_company")
        dim_funds = extract_warehouse(spark, table_name="dim_funds")
        stg_funding_rounds = extract_staging(spark, table_name="funding_rounds")

        # Join with dim_company to get the company's surrogate key (sk_company_id).
        df = df.join(
            dim_company.select("sk_company_id", "nk_company_id"),
            df.funded_object_id == dim_company.nk_company_id,
            "inner"
        )

        # Join with dim_funds to get the fund's surrogate key (sk_fund_id).
        df = df.join(
            dim_funds.select("sk_fund_id", "nk_fund_id"),
            df.investor_object_id == dim_funds.nk_fund_id,
            "inner"
        )

        # Prepare staging funding rounds data: express 'funded_at' as a yyyyMMdd integer.
        stg_funding_rounds = stg_funding_rounds.withColumn(
            "funded_at",
            F.date_format("funded_at", "yyyyMMdd").cast("integer")
        )

        # Look up the round date in dim_date with a LEFT join. The previous INNER join
        # dropped rounds whose funded_at was NULL or outside dim_date. Every investment in
        # such a round then lost all round attributes (type, participants, amounts) in the
        # LEFT join below. Unmatched dates now become NULL, so only ids present in dim_date
        # are emitted, as before.
        funded_dates = dim_date.select(F.col("date_id").alias("funded_date_id"))
        stg_funding_rounds = stg_funding_rounds.join(
            funded_dates,
            F.col("funded_at") == F.col("funded_date_id"),
            "left"
        ).withColumn("funded_at", F.col("funded_date_id"))

        # Join with the funding rounds staging table to get additional information.
        df = df.join(
            stg_funding_rounds.select(
                "funding_round_id", "funding_round_type", "participants",
                "raised_amount_usd", "raised_currency_code",
                "pre_money_valuation_usd", "post_money_valuation_usd",
                "funded_at"
            ),
            on="funding_round_id",
            how="left"
        )

        # Select the columns for the fact table and rename for clarity.
        fct_investments = df.select(
            F.col("investment_id").alias("dd_investment_id"),
            F.col("sk_company_id"),
            F.col("sk_fund_id"),
            F.col("funded_at"),
            F.col("funding_round_type"),
            F.col("participants").alias("num_of_participants"),
            F.col("raised_amount_usd"),
            F.col("pre_money_valuation_usd"),
            F.col("post_money_valuation_usd"),
        )

        print("Transformation process successful for table: fct_investments")

        # Log a success message for the ETL process.
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "success", "staging", "fct_investments", current_timestamp)]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

        # Persist the log row; previously it was built and discarded.
        load_log_msg(spark=spark, log_msg=log_msg)

        return fct_investments

    except Exception as e:

        # Log an error message with details about the failure.
        print(e)
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "failed", "staging", "fct_investments", current_timestamp, str(e))]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

        load_log_msg(spark=spark, log_msg=log_msg)

        return None


In [42]:
# Build the investments fact table from the staging investments.
fct_investments = transform_investments(spark=spark, df=stg_investments)

Extraction process successful for table: funding_rounds
Transformation process successful for table: fct_investments


In [44]:
def transform_ipos(spark, df):
    """
    Transform raw IPOs data into a cleaned and standardized fact table.

    The outcome is recorded via ``load_log_msg``.

    Args:
        spark: SparkSession
        df: Raw input DataFrame containing ipos data from staging db

    Returns:
        DataFrame: Cleaned and transformed ipos fact table, or None if the
        transformation failed (the error is printed and logged).
    """

    # Set up current timestamp for logging
    current_timestamp = datetime.now()

    try:
        # Extract dimension tables needed for joins.
        dim_date = extract_warehouse(spark, table_name="dim_date")
        dim_company = extract_warehouse(spark, table_name="dim_company")

        # Cast 'ipo_id' to integer.
        df = df.withColumn("ipo_id", F.col("ipo_id").cast("integer"))

        # Join with dim_company to get the company's surrogate key (sk_company_id).
        df = df.join(
            dim_company.select("sk_company_id", "nk_company_id"),
            df.object_id == dim_company.nk_company_id,
            "inner"
        )

        # Add a foreign key 'public_date_id' by formatting 'public_at' as a yyyyMMdd integer,
        # the same format as 'date_id' in dim_date.
        df = df.withColumn(
            "public_date_id",
            F.date_format(df.public_at, "yyyyMMdd").cast("integer")
        )

        # Left join with dim_date's key only; no dim_date attributes are selected below.
        # NOTE(review): a left join neither filters rows nor nulls public_date_id, so ids
        # absent from dim_date are passed through unchanged.
        dim_date_keys = dim_date.select("date_id")
        df = df.join(
            dim_date_keys,
            df.public_date_id == dim_date_keys.date_id,
            "left"
        )

        # Convert valuation and raised amounts to USD.
        df = df.withColumn("valuation_amount_usd", to_usd(currency_col="valuation_currency_code", amount_col="valuation_amount"))
        df = df.withColumn("raised_amount_usd", to_usd(currency_col="raised_currency_code", amount_col="raised_amount"))

        # Clean and normalize the stock symbol. Symbols made only of non-word characters,
        # digits or underscores are invalid. Empty strings are treated the same way,
        # since the regex needs at least one character and would otherwise keep them.
        cleaned_stock_symbol = F.trim(F.lower(F.col("stock_symbol")))
        invalid_symbol = (cleaned_stock_symbol == "") | cleaned_stock_symbol.rlike(r"^[\W\d_]+$")
        df = df.withColumn(
            "stock_symbol",
            F.when(invalid_symbol, F.lit(None)).otherwise(cleaned_stock_symbol)
        )

        # Lowercase and trim the description; empty strings become NULL, matching the other fact/dim cleaners.
        cleaned_description = F.trim(F.lower(F.col("source_description")))
        df = df.withColumn(
            "source_description",
            F.when(cleaned_description == "", F.lit(None))
             .otherwise(cleaned_description)
        )

        # Select the columns for the fact table and rename for clarity.
        fct_ipos = df.select(
            F.col("ipo_id").alias("dd_ipo_id"),
            F.col("sk_company_id"),
            F.col("valuation_amount_usd"),
            F.col("raised_amount_usd"),
            F.col("public_date_id").alias("public_at"),
            F.col("stock_symbol"),
            F.col("source_description").alias("ipo_description")
        )

        print("Transformation process successful for table: fct_ipos")

        # Log a success message for the ETL process.
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "success", "staging", "fct_ipos", current_timestamp)]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

        # Persist the log row; previously it was built and discarded.
        load_log_msg(spark=spark, log_msg=log_msg)

        return fct_ipos

    except Exception as e:

        # Log an error message with details about the failure.
        print(e)
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "failed", "staging", "fct_ipos", current_timestamp, str(e))]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

        load_log_msg(spark=spark, log_msg=log_msg)

        return None


In [45]:
# Build the IPOs fact table from the staging IPOs.
fct_ipos = transform_ipos(spark=spark, df=stg_ipos)

Transformation process successful for table: fct_ipos


In [46]:
def transform_acquisition(spark, df):
    """
    Transform raw acquisition data into a cleaned and standardized fact table.

    Rows are kept only when both the acquiring and the acquired company resolve
    to surrogate keys in dim_company. The outcome is recorded via ``load_log_msg``.

    Args:
        spark: SparkSession
        df: Raw input DataFrame containing acquisition data from staging db

    Returns:
        DataFrame: Cleaned and transformed acquisition fact table, or None if the
        transformation failed (the error is printed and logged).
    """

    # Set up current timestamp for logging
    current_timestamp = datetime.now()

    try:
        # Extract dimension tables needed for joins.
        dim_date = extract_warehouse(spark, table_name="dim_date")
        dim_company = extract_warehouse(spark, table_name="dim_company")

        # Set alias for acquiring and acquired companies from dim_company.
        dim_company_acquiring = dim_company.alias("acq")
        dim_company_acquired = dim_company.alias("acd")

        # Join with dim_company twice, under distinct column names, to get surrogate keys
        # for the acquiring and the acquired company.
        df = df.join(
            dim_company_acquiring.select(
                F.col("sk_company_id").alias("sk_acquiring_company_id"),
                F.col("nk_company_id").alias("nk_acquiring_company_id")
            ),
            df.acquiring_object_id == F.col("nk_acquiring_company_id"),
            "inner"
        )

        df = df.join(
            dim_company_acquired.select(
                F.col("sk_company_id").alias("sk_acquired_company_id"),
                F.col("nk_company_id").alias("nk_acquired_company_id")
            ),
            df.acquired_object_id == F.col("nk_acquired_company_id"),
            "inner"
        )

        # Add a foreign key 'acquired_date_id' by formatting 'acquired_at' as a yyyyMMdd integer,
        # the same format as 'date_id' in dim_date.
        df = df.withColumn(
            "acquired_date_id",
            F.date_format(df.acquired_at, "yyyyMMdd").cast("integer")
        )

        # Left join with dim_date's key only; no dim_date attributes are selected below.
        # NOTE(review): a left join neither filters rows nor nulls acquired_date_id, so ids
        # absent from dim_date are passed through unchanged.
        dim_date_keys = dim_date.select("date_id")
        df = df.join(
            dim_date_keys,
            df.acquired_date_id == dim_date_keys.date_id,
            "left"
        )

        # Convert acquisition price to USD.
        df = df.withColumn("price_amount_usd", to_usd(currency_col="price_currency_code", amount_col="price_amount"))

        # Clean and normalize term code: lowercase, trim, empty -> NULL.
        cleaned_term_code = F.trim(F.lower(F.col("term_code")))
        df = df.withColumn("term_code", F.when(cleaned_term_code == "", F.lit(None)).otherwise(cleaned_term_code))

        # Clean the source description: lowercase, trim, empty -> NULL.
        cleaned_description = F.trim(F.lower(F.col("source_description")))
        df = df.withColumn(
            "source_description",
            F.when(cleaned_description == "", F.lit(None))
             .otherwise(cleaned_description)
        )

        # Select the columns for the fact table and rename for clarity.
        fct_acquisition = df.select(
            F.col("acquisition_id").alias("dd_acquisition_id"),
            F.col("sk_acquiring_company_id"),
            F.col("sk_acquired_company_id"),
            F.col("price_amount_usd"),
            F.col("acquired_date_id").alias("acquired_at"),
            F.col("term_code"),
            F.col("source_description").alias("acquisition_description")
        )

        print("Transformation process successful for table: fct_acquisition")

        # Log a success message for the ETL process.
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "success", "staging", "fct_acquisition", current_timestamp)]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

        # Persist the log row; previously it was built and discarded.
        load_log_msg(spark=spark, log_msg=log_msg)

        return fct_acquisition

    except Exception as e:

        # Log an error message with details about the failure.
        print(e)
        log_msg = spark.sparkContext \
            .parallelize([("warehouse", "transform", "failed", "staging", "fct_acquisition", current_timestamp, str(e))]) \
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])

        load_log_msg(spark=spark, log_msg=log_msg)

        return None


In [47]:
# Build the acquisitions fact table from the staging acquisitions.
fct_acquisition = transform_acquisition(spark=spark, df=stg_acquisition)

Transformation process successful for table: fct_acquisition


## Load to Warehouse

In [48]:
def load_to_dwh(spark, df, table_name, source_name):
    """
    Replace the contents of a warehouse table with ``df`` (truncate, then append).

    Steps:
      1. If ``df`` is None (e.g. an upstream transform failed and returned None),
         log a failure and stop *before* truncating, so the table is not left empty.
      2. ``TRUNCATE TABLE <table_name> RESTART IDENTITY CASCADE`` via SQLAlchemy.
         In PostgreSQL, CASCADE also truncates tables holding foreign keys to this one.
         If the truncate fails, log and stop: appending to an uncleared table would
         duplicate rows.
      3. Append ``df`` through the PostgreSQL JDBC driver.

    Exactly one log row ("success" or "failed") is written via ``load_log_msg``
    for each failure point or for the final success.

    Args:
        spark: Active SparkSession (used to build the log DataFrame).
        df: DataFrame to load, or None.
        table_name: Target warehouse table; interpolated into the TRUNCATE statement,
            so it must be a trusted identifier.
        source_name: Value recorded in the log's 'source' column.
    """

    current_timestamp = datetime.now()

    def _log_failure(error):
        # Build and persist a single "failed" log row carrying the error text.
        log_message = spark.sparkContext\
            .parallelize([("warehouse", "load", "failed", source_name, table_name, current_timestamp, str(error))])\
            .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date', 'error_msg'])
        load_log_msg(spark=spark, log_msg=log_message)

    if df is None:
        error = f"No DataFrame to load for table: {table_name}"
        print(f"Load process failed: {error}")
        _log_failure(error)
        return

    # Initialise before the try so `finally` never hits an unbound name when
    # dwh_engine_sqlalchemy() itself raises (that used to mask the real error).
    conn = None
    try:

        # Establish connection to warehouse db
        conn = dwh_engine_sqlalchemy()

        with conn.begin() as connection:

            # Truncate the target table (and, via CASCADE, tables referencing it).
            connection.execute(text(f"TRUNCATE TABLE {table_name} RESTART IDENTITY CASCADE"))

        print(f"Success truncating table: {table_name}")

    except Exception as e:
        print(f"Error when truncating table: {e}")
        _log_failure(e)
        return

    finally:
        if conn is not None:
            conn.dispose()

    # Load the extracted DataFrame to warehouse db
    try:

        # Get warehouse db config
        DWH_DB_URL, DWH_DB_USER, DWH_DB_PASS = dwh_engine()

        # Set config; the driver matches the one used by extract_warehouse.
        properties = {
            "user" : DWH_DB_USER,
            "password" : DWH_DB_PASS,
            "driver" : "org.postgresql.Driver",
        }

        df.write.jdbc(url=DWH_DB_URL,
                      table=table_name,
                      mode="append",
                      properties=properties)

    except Exception as e:
        print(f"Load process failed: {e}")
        _log_failure(e)
        return

    print(f"Load process successful for table: {table_name}")

    # Written exactly once (previously written in both `try` and `finally`).
    log_message = spark.sparkContext\
        .parallelize([("warehouse", "load", "success", source_name, table_name, current_timestamp)])\
        .toDF(['step', 'process', 'status', 'source', 'table_name', 'etl_date'])

    load_log_msg(spark=spark, log_msg=log_message)

In [49]:
# Full reload of dim_company from the staging-derived DataFrame.
load_to_dwh(spark=spark, df=dim_company, table_name="dim_company", source_name="staging_db")

Success truncating table: dim_company
Load process successful for table: dim_company


In [50]:
# Full reload of dim_people.
load_to_dwh(spark=spark, df=dim_people, table_name="dim_people", source_name="staging_db")

Success truncating table: dim_people
Load process successful for table: dim_people


In [52]:
# Full reload of dim_funds.
load_to_dwh(spark=spark, df=dim_funds, table_name="dim_funds", source_name="staging_db")

Success truncating table: dim_funds
Load process successful for table: dim_funds


In [53]:
# Full reload of the company/people bridge table.
load_to_dwh(spark=spark, df=bridge_company_people, table_name="bridge_company_people", source_name="staging_db")

Success truncating table: bridge_company_people
Load process successful for table: bridge_company_people


In [54]:
# Full reload of the investments fact table.
load_to_dwh(spark=spark, df=fct_investments, table_name="fct_investments", source_name="staging_db")

Success truncating table: fct_investments
Load process successful for table: fct_investments


In [55]:
# Full reload of the IPOs fact table.
load_to_dwh(spark=spark, df=fct_ipos, table_name="fct_ipos", source_name="staging_db")

Success truncating table: fct_ipos
Load process successful for table: fct_ipos


In [56]:
# Full reload of the acquisitions fact table.
load_to_dwh(spark=spark, df=fct_acquisition, table_name="fct_acquisition", source_name="staging_db")

Success truncating table: fct_acquisition
Load process successful for table: fct_acquisition
