In [1]:
import pyspark
import os
import logging
from pyspark.sql import SparkSession
from sqlalchemy import create_engine, text
from dotenv import load_dotenv
from datetime import datetime

In [2]:
# Load .env and let its values override variables already set in the process environment
load_dotenv(".env", override=True)

# Every database is configured by five variables named <PREFIX>_DB_{HOST,USER,PASS,NAME,PORT}.
# os.getenv returns None for anything that is not set.
SOURCE_DB_HOST, SOURCE_DB_USER, SOURCE_DB_PASS, SOURCE_DB_NAME, SOURCE_DB_PORT = (
    os.getenv(f"SOURCE_DB_{field}") for field in ("HOST", "USER", "PASS", "NAME", "PORT")
)

STG_DB_HOST, STG_DB_USER, STG_DB_PASS, STG_DB_NAME, STG_DB_PORT = (
    os.getenv(f"STG_DB_{field}") for field in ("HOST", "USER", "PASS", "NAME", "PORT")
)

DWH_DB_HOST, DWH_DB_USER, DWH_DB_PASS, DWH_DB_NAME, DWH_DB_PORT = (
    os.getenv(f"DWH_DB_{field}") for field in ("HOST", "USER", "PASS", "NAME", "PORT")
)

LOG_DB_HOST, LOG_DB_USER, LOG_DB_PASS, LOG_DB_NAME, LOG_DB_PORT = (
    os.getenv(f"LOG_DB_{field}") for field in ("HOST", "USER", "PASS", "NAME", "PORT")
)

# MinIO endpoint and credentials (consumed by the Spark s3a settings)
MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY = (
    os.getenv(var_name) for var_name in ("MINIO_ENDPOINT", "MINIO_ACCESS_KEY", "MINIO_SECRET_KEY")
)


In [3]:
# Create the Spark session. The fs.s3a.* settings point Hadoop's S3A filesystem at the
# MinIO endpoint using path-style access with SSL disabled.
spark_builder = SparkSession.builder.appName("PySpark Project Dev")

for conf_key, conf_value in (
    ("spark.hadoop.fs.s3a.endpoint", MINIO_ENDPOINT),
    ("spark.hadoop.fs.s3a.access.key", MINIO_ACCESS_KEY),
    ("spark.hadoop.fs.s3a.secret.key", MINIO_SECRET_KEY),
    ("spark.hadoop.fs.s3a.path.style.access", "true"),
    ("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem"),
    ("spark.hadoop.fs.s3a.connection.ssl.enabled", "false"),
):
    spark_builder = spark_builder.config(conf_key, conf_value)

spark = spark_builder.getOrCreate()

spark

## Create database engine functions

In [4]:
def build_sqlalchemy_url(user, password, host, port, database):
    """Return a ``postgresql://`` SQLAlchemy URL with percent-encoded credentials.

    Characters such as ``@``, ``:`` or ``/`` inside a user name or password would
    otherwise be parsed as URL delimiters and yield a broken connection string.
    Values go through ``str()`` first, so a missing (None) setting renders as
    "None", exactly as a plain f-string would.
    """
    from urllib.parse import quote

    safe_user = quote(str(user), safe="")
    safe_password = quote(str(password), safe="")
    return f"postgresql://{safe_user}:{safe_password}@{host}:{port}/{database}"


def source_engine():
    """Return ``(jdbc_url, user, password)`` for the source database (used by Spark JDBC)."""
    SOURCE_DB_URL = f"jdbc:postgresql://{SOURCE_DB_HOST}:{SOURCE_DB_PORT}/{SOURCE_DB_NAME}"
    return SOURCE_DB_URL, SOURCE_DB_USER, SOURCE_DB_PASS

def staging_engine():
    """Return ``(jdbc_url, user, password)`` for the staging database (used by Spark JDBC)."""
    STG_DB_URL = f"jdbc:postgresql://{STG_DB_HOST}:{STG_DB_PORT}/{STG_DB_NAME}"
    return STG_DB_URL, STG_DB_USER, STG_DB_PASS

def staging_engine_sqlalchemy():
    """Return a SQLAlchemy engine for the staging database (used for TRUNCATE statements)."""
    return create_engine(build_sqlalchemy_url(STG_DB_USER, STG_DB_PASS, STG_DB_HOST, STG_DB_PORT, STG_DB_NAME))

def dwh_engine_sqlalchemy():
    """Return a SQLAlchemy engine for the data warehouse (used for TRUNCATE statements)."""
    return create_engine(build_sqlalchemy_url(DWH_DB_USER, DWH_DB_PASS, DWH_DB_HOST, DWH_DB_PORT, DWH_DB_NAME))

def dwh_engine():
    """Return ``(jdbc_url, user, password)`` for the data warehouse (used by Spark JDBC)."""
    DWH_DB_URL = f"jdbc:postgresql://{DWH_DB_HOST}:{DWH_DB_PORT}/{DWH_DB_NAME}"
    return DWH_DB_URL, DWH_DB_USER, DWH_DB_PASS

def log_engine():
    """Return ``(jdbc_url, user, password)`` for the ETL log database (used by Spark JDBC)."""
    LOG_DB_URL = f"jdbc:postgresql://{LOG_DB_HOST}:{LOG_DB_PORT}/{LOG_DB_NAME}"
    return LOG_DB_URL, LOG_DB_USER, LOG_DB_PASS

## Set Up Logging Function

In [5]:
def load_log_msg(spark: SparkSession, log_msg):
    """Append ETL log row(s) to the ``etl_log`` table of the log database.

    Args:
        spark: Active SparkSession (not used directly; kept for a uniform call signature).
        log_msg: Spark DataFrame holding the log row(s) to append.
    """
    log_url, log_user, log_password = log_engine()

    log_msg.write.jdbc(
        url=log_url,
        table="etl_log",
        mode="append",
        properties={
            "user": log_user,
            "password": log_password,
            "driver": "org.postgresql.Driver",
        },
    )

## Extract Database

In [6]:
def extract_database(spark: SparkSession, table_name: str):
    """Read one table from the source database over JDBC and log the outcome.

    Args:
        spark: Active SparkSession.
        table_name: Source table to read.

    Returns:
        The table as a Spark DataFrame, or None if the read failed. In both cases
        one row is appended to the ETL log.
    """
    SOURCE_DB_URL, SOURCE_DB_USER, SOURCE_DB_PASS = source_engine()
    connection_properties = dict(
        user=SOURCE_DB_USER,
        password=SOURCE_DB_PASS,
        driver="org.postgresql.Driver",
    )

    started_at = datetime.now()
    log_columns = ["step", "process", "status", "source", "table_name", "etl_date"]

    try:
        extracted = spark.read.jdbc(url=SOURCE_DB_URL, table=table_name, properties=connection_properties)

        print(f"Extraction process successful for table: {table_name}")

        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "success", "source_db", table_name, started_at)]) \
            .toDF(log_columns)

        return extracted

    except Exception as err:
        print(f"Extraction process failed: {err}")

        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "failed", "source_db", table_name, started_at, str(err))]) \
            .toDF(log_columns + ["error_msg"])

    finally:
        # Runs on both paths, so every extraction attempt leaves a log row
        load_log_msg(spark=spark, log_msg=log_message)

In [7]:
# Pull each source-database table into its own DataFrame (same order as listed)
(acquisition_df, company_df, funding_rounds_df,
 funds_df, investments_df, ipos_df) = (
    extract_database(spark=spark, table_name=source_table)
    for source_table in ("acquisition", "company", "funding_rounds", "funds", "investments", "ipos")
)

Extraction process successful for table: acquisition
Extraction process successful for table: company
Extraction process successful for table: funding_rounds
Extraction process successful for table: funds
Extraction process successful for table: investments
Extraction process successful for table: ipos


## Extract CSV

In [8]:
def extract_csv(spark: SparkSession, file_name: str, path: str = "data/"):
    """Read a CSV file (with a header row) into a Spark DataFrame and log the outcome.

    No schema inference is requested, so every column is read as a string.

    Args:
        spark: Active SparkSession.
        file_name: CSV file name, relative to ``path``.
        path: Directory containing the file (defaults to ``"data/"``).

    Returns:
        The DataFrame, or None if reading failed. In both cases one row is
        appended to the ETL log.
    """
    import posixpath

    # posixpath.join copes with `path` given with or without a trailing slash
    file_path = posixpath.join(path, file_name)

    # Set current timestamp for logging
    current_timestamp = datetime.now()
    log_columns = ["step", "process", "status", "source", "table_name", "etl_date"]

    try:

        # Read data
        df = spark.read.csv(file_path, header=True)

        print(f"Extraction process successful for file: {file_name}")

        # Set success log message
        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "success", "csv", file_name, current_timestamp)]) \
            .toDF(log_columns)

        return df

    except Exception as e:
        print(f"Extraction process failed: {e}")

        # Set failed log message
        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "failed", "csv", file_name, current_timestamp, str(e))]) \
            .toDF(log_columns + ["error_msg"])

    finally:
        load_log_msg(spark=spark, log_msg=log_message)
        

In [9]:
# Read the two CSV sources (people first, then relationships)
people_df, relationship_df = (
    extract_csv(spark=spark, file_name=csv_name) for csv_name in ("people.csv", "relationships.csv")
)

Extraction process successful for file: people.csv
Extraction process successful for file: relationships.csv


## Extract API

In [10]:
import requests

def extract_api(spark: SparkSession, start_date: str, end_date: str,
                timeout: float = 30,
                base_url: str = "https://api-milestones.vercel.app/api/data"):
    """Fetch milestones for a date range from the HTTP API and log the outcome.

    Args:
        spark: Active SparkSession.
        start_date: Range start, sent as the ``start_date`` query parameter.
        end_date: Range end, sent as the ``end_date`` query parameter.
        timeout: Seconds to wait for the API before giving up (default 30).
        base_url: API endpoint (defaults to the milestones API).

    Returns:
        A Spark DataFrame built from the JSON payload, or None when the API
        returned no data or the request failed. One log row is written either way.
    """
    # Set current timestamp for logging
    current_timestamp = datetime.now()
    log_columns = ["step", "process", "status", "source", "table_name", "etl_date"]
    df = None

    try:
        # `params` lets requests build and encode the query string; the timeout
        # keeps an unresponsive API from blocking the notebook indefinitely.
        response = requests.get(
            base_url,
            params={"start_date": start_date, "end_date": end_date},
            timeout=timeout,
        )
        response.raise_for_status()

        data = response.json()

        if not data:
            print("There is no data in this range of date")
        else:
            print("Extraction process successful for milestones table")
            df = spark.createDataFrame(data)

        # Set success log message (also used when the range is simply empty)
        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "success", "api", "milestones", current_timestamp)]) \
            .toDF(log_columns)

    except Exception as e:
        print(f"Extraction process failed: {e}")
        df = None

        # Set failed log message
        log_message = spark.sparkContext \
            .parallelize([("sources", "extraction", "failed", "api", "milestones", current_timestamp, str(e))]) \
            .toDF(log_columns + ["error_msg"])

    finally:
        load_log_msg(spark=spark, log_msg=log_message)

    return df
        

In [11]:
# Milestones come from the HTTP API for the 2014-01-01 .. 2015-01-01 window
milestones_df = extract_api(
    spark=spark,
    start_date="2014-01-01",
    end_date="2015-01-01",
)

Extraction process successful for milestones table


## Data Profiling

In [12]:
# Check Percentage of Missing Values for each column with pyspark
import pandas as pd
import json
from pyspark.sql.functions import col, count, when, round

def check_missing_values(df):
    """Return the percentage of null values in each column of a Spark DataFrame.

    Args:
        df: Spark DataFrame to profile.

    Returns:
        dict mapping each column name to its null percentage (0-100, rounded to
        2 decimals). For an empty DataFrame the percentage is undefined, so every
        value is None and the division by a zero row count is skipped.
    """
    total_data = df.count()

    if total_data == 0:
        return {column_name: None for column_name in df.columns}

    # `round`, `count`, `when` and `col` are the pyspark.sql.functions versions
    # imported in this cell (not the Python builtins).
    missing_pct_exprs = [
        round((count(when(col(column_name).isNull(), column_name)) / total_data) * 100, 2).alias(column_name)
        for column_name in df.columns
    ]

    return df.select(missing_pct_exprs).collect()[0].asDict()

In [13]:
# Report label -> DataFrame; the insertion order is the order used in every report section
profiled_tables = {
    "Acquisition": acquisition_df,
    "Company": company_df,
    "Funding Rounds": funding_rounds_df,
    "Funds": funds_df,
    "Investments": investments_df,
    "IPOS": ipos_df,
    "People": people_df,
    "Relationships": relationship_df,
    "Milestones": milestones_df,
}

data_profiling_report = {
    "Created by": "Rico Febrian",
    "Checking Date": datetime.now().strftime('%d/%m/%y'),
    "Column Information": {
        label: {"count": len(frame.columns), "columns": frame.columns}
        for label, frame in profiled_tables.items()
    },
    "Check Data Size": {label: frame.count() for label, frame in profiled_tables.items()},
    "Data Type For Each Column": {label: frame.dtypes for label, frame in profiled_tables.items()},
    "Check Missing Value": {label: check_missing_values(frame) for label, frame in profiled_tables.items()},
}

# Pretty-print the report as JSON
# print(json.dumps(data_profiling_report, indent=4))

In [14]:
# Create a function to save the final report to a JSON file
def save_to_json(dict_result: dict, filename: str) -> None:
    """
    Save the data profiling result to ``<filename>.json`` (4-space indent).

    The dictionary is serialized *before* the file is opened. If serialization
    fails (e.g. a non-JSON-serializable value), an existing report file is left
    untouched instead of being truncated to an empty file. Errors are printed,
    not raised (best-effort).

    Args:
        dict_result (dict): Data profiling result to save to a JSON file.
        filename (str): Path of the JSON file without the ``.json`` extension.

    Returns:
        None
    """

    try:
        serialized = json.dumps(dict_result, indent=4)

        with open(f'{filename}.json', 'w') as file:
            file.write(serialized)

    except Exception as e:
        print(f"Error: {e}")

In [15]:
# Writes data_profiling_report.json relative to the current working directory
save_to_json(data_profiling_report, "data_profiling_report")

## Dump to Data Lake (MinIO)

In [16]:
# NOTE(review): disabled cell. The data-lake (MinIO) dump is commented out, so nothing
# visible in this notebook currently uses the s3a settings of the Spark session.
# When re-enabled: writes `df` as Parquet to s3a://<minio_path>, overwriting existing data.
# def load_to_minio(df, minio_path):

#     try:
#         df.write \
#           .mode("overwrite") \
#           .parquet(f"s3a://{minio_path}")

#         print(f"Initial load completed to {minio_path}")

#     except Exception as e:
#         print(f"Failed to load data: {e}")

In [17]:
# NOTE(review): disabled cell. These calls depend on the commented-out `load_to_minio`
# above; each extracted DataFrame would land under the raw-data/ prefix.
# load_to_minio(df=acquisition_df, minio_path="raw-data/acquisition_table")
# load_to_minio(df=company_df, minio_path="raw-data/company_table")
# load_to_minio(df=funding_rounds_df, minio_path="raw-data/funding_rounds_table")
# load_to_minio(df=funds_df, minio_path="raw-data/funds_table")
# load_to_minio(df=investments_df, minio_path="raw-data/investments_table")
# load_to_minio(df=ipos_df, minio_path="raw-data/ipos_table")
# load_to_minio(df=people_df, minio_path="raw-data/people_table")
# load_to_minio(df=relationship_df, minio_path="raw-data/relationships_table")
# load_to_minio(df=milestones_df, minio_path="raw-data/milestones_table")

## Load to Staging

In [18]:
def load_to_stg(spark, df, table_name, source_name):
    """Replace the contents of a staging table with ``df`` and log the outcome.

    The target table is first truncated (``RESTART IDENTITY CASCADE``) through
    SQLAlchemy, then ``df`` is appended over JDBC. A "failed" log row is written
    if the truncate fails, and exactly one "success" or "failed" row is written
    for the load step.

    Args:
        spark: Active SparkSession, used to build the log rows.
        df: Spark DataFrame to load. ``None`` (e.g. an extraction that returned
            nothing) makes the load step fail, which is logged.
        table_name: Staging table name. It is interpolated directly into the
            TRUNCATE statement, so it must come from trusted code.
        source_name: Source label recorded in the log (e.g. "csv", "api").
    """
    current_timestamp = datetime.now()
    log_columns = ['step', 'process', 'status', 'source', 'table_name', 'etl_date']

    # --- Step 1: truncate the target staging table -----------------------
    # Initialised up front so cleanup is safe even if creating the engine raises.
    engine = None
    try:
        engine = staging_engine_sqlalchemy()

        with engine.begin() as connection:
            connection.execute(text(f"TRUNCATE TABLE {table_name} RESTART IDENTITY CASCADE"))

        print(f"Success truncating table: {table_name}")

    except Exception as e:
        print(f"Error when truncating table: {e}")

        log_message = spark.sparkContext\
            .parallelize([("staging", "load", "failed", source_name, table_name, current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

        load_log_msg(spark=spark, log_msg=log_message)

    finally:
        if engine is not None:
            engine.dispose()

    # NOTE: a failed truncate does not abort the load. The append below still
    # runs, so rows may be duplicated if the table already held data.

    # --- Step 2: append the extracted DataFrame to the staging table ------
    try:
        STG_DB_URL, STG_DB_USER, STG_DB_PASS = staging_engine()

        properties = {
            "user": STG_DB_USER,
            "password": STG_DB_PASS,
            "driver": "org.postgresql.Driver",  # same driver as the other JDBC reads/writes
        }

        df.write.jdbc(url=STG_DB_URL,
                      table=table_name,
                      mode="append",
                      properties=properties)

        print(f"Load process successful for table: {table_name}")

        log_message = spark.sparkContext\
            .parallelize([("staging", "load", "success", source_name, table_name, current_timestamp)])\
            .toDF(log_columns)

    except Exception as e:
        print(f"Load process failed: {e}")

        log_message = spark.sparkContext\
            .parallelize([("staging", "load", "failed", source_name, table_name, current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

    # Exactly one log row for the load step, whichever branch ran
    load_log_msg(spark=spark, log_msg=log_message)

In [19]:
# (DataFrame, staging table, source label), loaded in this order
staging_loads = [
    (people_df, "people", "csv"),
    (relationship_df, "relationships", "csv"),
    (company_df, "company", "source_db"),
    (funding_rounds_df, "funding_rounds", "source_db"),
    (funds_df, "funds", "source_db"),
    (acquisition_df, "acquisition", "source_db"),
    (ipos_df, "ipos", "source_db"),
    (investments_df, "investments", "source_db"),
    (milestones_df, "milestones", "api"),
]

for stg_df, stg_table, stg_source in staging_loads:
    load_to_stg(spark=spark, df=stg_df, table_name=stg_table, source_name=stg_source)

Success truncating table: people
Load process successful for table: people
Success truncating table: relationships
Load process successful for table: relationships
Success truncating table: company
Load process successful for table: company
Success truncating table: funding_rounds
Load process successful for table: funding_rounds
Success truncating table: funds
Load process successful for table: funds
Success truncating table: acquisition
Load process successful for table: acquisition
Success truncating table: ipos
Load process successful for table: ipos
Success truncating table: investments
Load process successful for table: investments
Success truncating table: milestones
Load process successful for table: milestones


In [20]:
# Kept commented out: the transformation and warehouse-load cells below still use `spark`.
# spark.stop()

## Data Transformation

In [23]:
from pyspark.sql import functions as F

def transform_company(spark, df):
    """Build the ``dim_company`` DataFrame from the extracted company data.

    Adds ``entity_type`` from the ``object_id`` prefix ("c:" -> "company",
    "f:" -> "fund", anything else -> null). Adds ``full_address`` as address1 and
    address2 joined with ", " via ``concat_ws``. Then keeps only the dimension
    columns.

    Args:
        spark: Active SparkSession, used to build and persist the ETL log row.
        df: Company DataFrame with object_id, address1, address2, region, city
            and country_code columns.

    Returns:
        The dim_company DataFrame, or None if the transformation failed. One
        "success" or "failed" row is written to the ETL log either way.
    """
    current_timestamp = datetime.now()
    log_columns = ['step', 'process', 'status', 'source', 'table_name', 'etl_date']
    dim_company = None

    try:

        # Add new column, entity type
        df = df.withColumn(
            "entity_type",
            F.when(F.col("object_id").startswith("c:"), "company")
             .when(F.col("object_id").startswith("f:"), "fund")
             .otherwise(None)
        )

        # Concat address
        df = df.withColumn(
            "full_address",
            F.concat_ws(", ", F.col("address1"), F.col("address2"))
        )

        # Select columns
        dim_company = df.select(
            F.col("object_id").alias("nk_company_id"),
            F.col("entity_type"),
            F.col("full_address"),
            F.col("region"),
            F.col("city"),
            F.col("country_code")
        )

        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "success", "staging", "dim_company", current_timestamp)])\
            .toDF(log_columns)

    except Exception as e:
        print(e)

        # Failures must be recorded with status "failed"
        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "failed", "staging", "dim_company", current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

    finally:
        # Persist the log row; building it alone does not write anything
        load_log_msg(spark=spark, log_msg=log_msg)

    return dim_company
        
    

In [24]:
# Build the company dimension from the extracted company table
dim_company = transform_company(spark=spark, df=company_df)

In [135]:
def transform_people(spark, df):
    """Build the ``dim_people`` DataFrame from the extracted people data.

    ``full_name`` is first_name and last_name joined with a space (``concat_ws``).
    Only nk_people_id (from object_id), full_name and affiliation_name are kept.

    Args:
        spark: Active SparkSession, used to build and persist the ETL log row.
        df: People DataFrame with object_id, first_name, last_name and
            affiliation_name columns.

    Returns:
        The dim_people DataFrame, or None if the transformation failed. One
        "success" or "failed" row is written to the ETL log either way.
    """
    current_timestamp = datetime.now()
    log_columns = ['step', 'process', 'status', 'source', 'table_name', 'etl_date']
    dim_people = None

    try:

        # Concat first and last name
        df = df.withColumn(
            "full_name",
            F.concat_ws(" ", F.col("first_name"), F.col("last_name"))
        )

        # Select columns
        dim_people = df.select(
            F.col("object_id").alias("nk_people_id"),
            F.col("full_name"),
            F.col("affiliation_name")
        )

        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "success", "staging", "dim_people", current_timestamp)])\
            .toDF(log_columns)

    except Exception as e:
        print(e)

        # Failures must be recorded with status "failed"
        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "failed", "staging", "dim_people", current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

    finally:
        # Persist the log row; building it alone does not write anything
        load_log_msg(spark=spark, log_msg=log_msg)

    return dim_people
        
    

In [136]:
# Build the people dimension from the people CSV extract
dim_people = transform_people(spark=spark, df=people_df)

In [137]:
# Preview the first 20 rows (PySpark's default for show)
dim_people.show(n=20)

+------------+----------------+--------------------+
|nk_people_id|       full_name|    affiliation_name|
+------------+----------------+--------------------+
|         p:2|     Ben Elowitz|           Blue Nile|
|         p:3|  Kevin Flaherty|            Wetpaint|
|         p:4|    Raju Vegesna|                Zoho|
|         p:5|       Ian Wenig|                Zoho|
|         p:6|      Kevin Rose|        i/o Ventures|
|         p:7|     Jay Adelson|                Digg|
|         p:8|      Owen Byrne|                Digg|
|         p:9|  Ron Gorodetzky|                Digg|
|        p:10| Mark Zuckerberg|            Facebook|
|        p:11|Dustin Moskovitz|            Facebook|
|        p:12|  Owen Van Natta|               Asana|
|        p:13|     Matt Cohler|            LinkedIn|
|        p:14|    Chris Hughes|General Catalyst ...|
|        p:16|      Alex Welch|            C7 Group|
|        p:17|  Darren Crystal|         Photobucket|
|        p:18|   Michael Clark|   Photobucket 

In [68]:
def extract_warehouse(spark: SparkSession, table_name):
    """Read a table from the data warehouse over JDBC.

    Args:
        spark: Active SparkSession.
        table_name: Warehouse table to read.

    Returns:
        The table as a Spark DataFrame, or None if the read failed (the error is
        printed, not raised, and nothing is logged).
    """
    DWH_DB_URL, DWH_DB_USER, DWH_DB_PASS = dwh_engine()

    jdbc_options = {
        "user": DWH_DB_USER,
        "password": DWH_DB_PASS,
        "driver": "org.postgresql.Driver",
    }

    try:
        return spark.read.jdbc(url=DWH_DB_URL, table=table_name, properties=jdbc_options)
    except Exception as err:
        print(err)

In [114]:
def transform_fund(spark, df):
    """Build the ``dim_fund`` DataFrame from the extracted funds data.

    Steps:
      * rename ``name`` -> ``fund_name`` and ``source_description`` -> ``fund_description``;
      * convert ``raised_amount`` to USD with fixed rates, rounded to 2 decimals;
      * derive an integer ``yyyyMMdd`` key from ``funded_at`` (output as ``funded_at``);
      * turn blank / whitespace-only descriptions into null.

    Args:
        spark: Active SparkSession (reads ``dim_date`` and writes the ETL log).
        df: Funds DataFrame with object_id, name, source_description,
            raised_amount, raised_currency_code and funded_at columns.

    Returns:
        The dim_fund DataFrame, or None if the transformation failed. One
        "success" or "failed" row is written to the ETL log either way.
    """
    current_timestamp = datetime.now()
    log_columns = ['step', 'process', 'status', 'source', 'table_name', 'etl_date']
    dim_fund = None

    try:

        # Extract dim_date (extract_warehouse returns None on failure, which makes
        # the join below raise and the transform be logged as failed)
        dim_date = extract_warehouse(spark, table_name="dim_date")

        # Rename columns
        df = df.withColumnRenamed("name", "fund_name") \
               .withColumnRenamed("source_description", "fund_description")

        # Convert raised amounts to USD. NOTE: these are fixed, hard-coded rates,
        # not dated FX rates. Currencies not listed pass through unconverted.
        df = df.withColumn(
            "raised_amount_usd",
            F.round(
                F.when(F.col("raised_currency_code") == "USD", F.col("raised_amount"))
                 .when(F.col("raised_currency_code") == "CAD", F.col("raised_amount") * 0.72)
                 .when(F.col("raised_currency_code") == "EUR", F.col("raised_amount") * 1.14)
                 .when(F.col("raised_currency_code") == "SEK", F.col("raised_amount") * 0.10)
                 .when(F.col("raised_currency_code") == "AUD", F.col("raised_amount") * 0.64)
                 .when(F.col("raised_currency_code") == "JPY", F.col("raised_amount") * 0.007)
                 .when(F.col("raised_currency_code") == "GBP", F.col("raised_amount") * 1.33)
                 .otherwise(F.col("raised_amount")),
                2
            )
        )

        # Date key in the same yyyyMMdd integer format as dim_date.date_id
        df = df.withColumn(
            "funded_date_id",
            F.date_format(df.funded_at, "yyyyMMdd").cast("integer")
        )

        # NOTE(review): no dim_date column is selected below, so this left join
        # does not change the output values (assuming date_id is unique in dim_date).
        df = df.join(
            dim_date,
            df.funded_date_id == dim_date.date_id,
            "left"
        )

        # Remove empty string in fund_description column and mark as null
        df = df.withColumn(
            "fund_description",
            F.when(F.trim(df.fund_description) == "", None)
             .otherwise(df.fund_description)
        )

        # Select columns
        dim_fund = df.select(
            F.col("object_id").alias("nk_fund_id"),
            F.col("fund_name"),
            F.col("raised_amount_usd"),
            F.col("funded_date_id").alias("funded_at"),
            F.col("fund_description")
        )

        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "success", "staging", "dim_fund", current_timestamp)])\
            .toDF(log_columns)

    except Exception as e:
        print(e)

        # Failures must be recorded with status "failed"
        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "failed", "staging", "dim_fund", current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

    finally:
        # Persist the log row; building it alone does not write anything
        load_log_msg(spark=spark, log_msg=log_msg)

    return dim_fund
        
    

In [115]:
# Build the fund dimension from the extracted funds table
dim_fund = transform_fund(spark=spark, df=funds_df)

In [116]:
# NOTE(review): `load_to_dwh` is defined further down under "Load to Warehouse" (it ran
# earlier as In [32]); on a fresh kernel, run that definition before this cell.
load_to_dwh(spark=spark, df=dim_fund, table_name="dim_fund", source_name="staging_db")

Success truncating table: dim_fund
Load process successful for table: dim_fund


In [147]:
def transform_relationship(spark, df):
    """Build ``bridge_company_people`` from the extracted relationships data.

    Output columns: sk_company_id, sk_people_id, title, is_past,
    relationship_start_at, relationship_end_at.

    Surrogate keys are looked up in the warehouse ``dim_company`` and
    ``dim_people`` tables with inner joins, so relationships whose company or
    person is not in those dimensions are dropped. ``start_at`` / ``end_at``
    become ``yyyyMMdd`` integer keys and are resolved against ``dim_date`` with
    LEFT joins. A relationship with a missing or unknown date is therefore kept,
    with a null date key. An inner join here would silently drop every
    relationship that has no ``end_at``.

    NOTE(review): confirm the warehouse relationship_*_at columns accept NULL.
    The dimension tables must already be loaded into the warehouse.

    Args:
        spark: Active SparkSession (reads the dimensions and writes the ETL log).
        df: Relationships DataFrame with relationship_object_id, person_object_id,
            title, is_past, start_at and end_at columns.

    Returns:
        The bridge DataFrame, or None if the transformation failed. One
        "success" or "failed" row is written to the ETL log either way.
    """
    current_timestamp = datetime.now()
    log_columns = ['step', 'process', 'status', 'source', 'table_name', 'etl_date']
    bridge_company_people = None

    try:

        # Extract selected tables
        dim_date = extract_warehouse(spark, table_name="dim_date")
        dim_company = extract_warehouse(spark, table_name="dim_company")
        dim_people = extract_warehouse(spark, table_name="dim_people")

        # Resolve sk_company_id via dim_company
        df = df.join(
            dim_company.select("sk_company_id", "nk_company_id"),
            df.relationship_object_id == dim_company.nk_company_id,
            "inner"
        )

        # Resolve sk_people_id via dim_people
        df = df.join(
            dim_people.select("sk_people_id", "nk_people_id"),
            df.person_object_id == dim_people.nk_people_id,
            "inner"
        )

        # Date keys in the same yyyyMMdd integer format as dim_date.date_id
        df = df.withColumn("start_date_key", F.date_format("start_at", "yyyyMMdd").cast("integer")) \
               .withColumn("end_date_key", F.date_format("end_at", "yyyyMMdd").cast("integer"))

        # dim_date is joined twice (start and end), so each side gets its own alias
        dim_date_start = dim_date.alias("start_date")
        dim_date_end = dim_date.alias("end_date")

        df = df.join(
            dim_date_start,
            df.start_date_key == F.col("start_date.date_id"),
            "left"
        ).join(
            dim_date_end,
            df.end_date_key == F.col("end_date.date_id"),
            "left"
        )

        # Take the date keys from dim_date so they are either valid date_ids or null
        bridge_company_people = df.select(
            F.col("sk_company_id"),
            F.col("sk_people_id"),
            F.col("title"),
            F.col("is_past"),
            F.col("start_date.date_id").alias("relationship_start_at"),
            F.col("end_date.date_id").alias("relationship_end_at")
        )

        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "success", "staging", "bridge_company_people", current_timestamp)])\
            .toDF(log_columns)

    except Exception as e:
        print(e)

        # Failures must be recorded with status "failed"
        log_msg = spark.sparkContext\
            .parallelize([("warehouse", "transform", "failed", "staging", "bridge_company_people", current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

    finally:
        # Persist the log row; building it alone does not write anything
        load_log_msg(spark=spark, log_msg=log_msg)

    return bridge_company_people
        
    

In [148]:
# NOTE(review): transform_relationship reads dim_company and dim_people back from the
# warehouse, so the cells that load those dimensions (further down) must run first.
bridge_company_people = transform_relationship(spark=spark, df=relationship_df)

bridge_company_people.show()

+-------------+------------+--------------------+-------+---------------------+-------------------+
|sk_company_id|sk_people_id|               title|is_past|relationship_start_at|relationship_end_at|
+-------------+------------+--------------------+-------+---------------------+-------------------+
|            1|      180703|Sr. EMEA Operatio...|  false|             20050101|           20080101|
|            1|       53575|        Board Member|  false|             20010101|           20070101|
|            1|         599|Executive Vice Pr...|  false|             20000801|           20070901|
|            2|      168571|          Co-founder|  false|             20060101|           20100101|
|            2|      101437|Designer & Co-fou...|  false|             20060701|           20070201|
|            2|      101437|          VP Product|  false|             20070201|           20090601|
|            2|      101437|                 COO|  false|             20090601|           20100701|


In [151]:
# NOTE(review): relies on `load_to_dwh`, defined further down under "Load to Warehouse".
load_to_dwh(spark=spark, df=bridge_company_people, table_name="bridge_company_people", source_name="staging_db")

Success truncating table: bridge_company_people
Load process successful for table: bridge_company_people


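## Load to Warehouse

Note: `load_to_dwh` is defined in this section but is also called by cells above (the `dim_fund` and `bridge_company_people` loads). On a fresh kernel, run this definition before those cells.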

In [32]:
def load_to_dwh(spark, df, table_name, source_name):
    """Replace the contents of a warehouse table with ``df`` and log the outcome.

    The target table is first truncated (``RESTART IDENTITY CASCADE``) through
    SQLAlchemy, then ``df`` is appended over JDBC. A "failed" log row is written
    if the truncate fails, and exactly one "success" or "failed" row is written
    for the load step.

    Args:
        spark: Active SparkSession, used to build the log rows.
        df: Spark DataFrame to load. ``None`` (e.g. a failed transform) makes the
            load step fail, which is logged.
        table_name: Warehouse table name. It is interpolated directly into the
            TRUNCATE statement, so it must come from trusted code.
        source_name: Source label recorded in the log (e.g. "staging_db").
    """
    current_timestamp = datetime.now()
    log_columns = ['step', 'process', 'status', 'source', 'table_name', 'etl_date']

    # --- Step 1: truncate the target warehouse table ---------------------
    # Initialised up front so cleanup is safe even if creating the engine raises.
    engine = None
    try:
        engine = dwh_engine_sqlalchemy()

        with engine.begin() as connection:
            connection.execute(text(f"TRUNCATE TABLE {table_name} RESTART IDENTITY CASCADE"))

        print(f"Success truncating table: {table_name}")

    except Exception as e:
        print(f"Error when truncating table: {e}")

        log_message = spark.sparkContext\
            .parallelize([("warehouse", "load", "failed", source_name, table_name, current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

        load_log_msg(spark=spark, log_msg=log_message)

    finally:
        if engine is not None:
            engine.dispose()

    # NOTE: a failed truncate does not abort the load. The append below still
    # runs, so rows may be duplicated if the table already held data.

    # --- Step 2: append the transformed DataFrame to the warehouse --------
    try:
        DWH_DB_URL, DWH_DB_USER, DWH_DB_PASS = dwh_engine()

        properties = {
            "user": DWH_DB_USER,
            "password": DWH_DB_PASS,
            "driver": "org.postgresql.Driver",  # same driver as the other JDBC reads/writes
        }

        df.write.jdbc(url=DWH_DB_URL,
                      table=table_name,
                      mode="append",
                      properties=properties)

        print(f"Load process successful for table: {table_name}")

        log_message = spark.sparkContext\
            .parallelize([("warehouse", "load", "success", source_name, table_name, current_timestamp)])\
            .toDF(log_columns)

    except Exception as e:
        print(f"Load process failed: {e}")

        log_message = spark.sparkContext\
            .parallelize([("warehouse", "load", "failed", source_name, table_name, current_timestamp, str(e))])\
            .toDF(log_columns + ['error_msg'])

    # Exactly one log row for the load step, whichever branch ran
    load_log_msg(spark=spark, log_msg=log_message)

In [34]:
# Load the company dimension into the warehouse
load_to_dwh(spark=spark, df=dim_company, table_name="dim_company", source_name="staging_db")

Success truncating table: dim_company
Load process successful for table: dim_company


In [139]:
# Load the people dimension into the warehouse
load_to_dwh(spark=spark, df=dim_people, table_name="dim_people", source_name="staging_db")

Success truncating table: dim_people
Load process successful for table: dim_people
