1. Customer Data
customer_id: string
first_name: string
last_name: string
date_of_birth: date
email: string
phone_number: string
address: string
city: string
state: string
zip_code: string
country: string
customer_since: date
credit_score: integer
risk_segment: string

In [0]:
2. Account Data
account_id: string
customer_id: string
account_type: string (checking, savings, investment)
account_status: string (active, closed, suspended)
open_date: date
close_date: date
currency: string
branch_id: string
interest_rate: float
balance: decimal
last_activity_date: date

In [0]:
3. Transaction Data
transaction_id: string
account_id: string
transaction_date: timestamp
transaction_type: string (deposit, withdrawal, transfer, payment)
amount: decimal
currency: string
description: string
merchant_name: string
merchant_category: string
transaction_status: string (completed, pending, failed, reversed)
channel: string (online, mobile, branch, atm)
location: string
is_international: boolean
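
The field listing above maps naturally onto an explicit Spark schema. As a minimal sketch (the `TYPE_MAP`, `TRANSACTION_FIELDS`, and `to_ddl` names are hypothetical, and mapping `decimal` to `DECIMAL(18,2)` is an assumption about precision), the listing can be turned into a DDL-formatted schema string, which `spark.read.schema(...)` accepts instead of relying on schema inference:

```python
# Hypothetical helper: build a Spark DDL schema string from the field listing.
# The DECIMAL precision/scale is an assumption; adjust it to the real data.
TYPE_MAP = {
    "string": "STRING",
    "timestamp": "TIMESTAMP",
    "decimal": "DECIMAL(18,2)",
    "boolean": "BOOLEAN",
    "date": "DATE",
    "integer": "INT",
    "float": "FLOAT",
}

# Field names and logical types copied from the Transaction Data listing
TRANSACTION_FIELDS = [
    ("transaction_id", "string"),
    ("account_id", "string"),
    ("transaction_date", "timestamp"),
    ("transaction_type", "string"),
    ("amount", "decimal"),
    ("currency", "string"),
    ("description", "string"),
    ("merchant_name", "string"),
    ("merchant_category", "string"),
    ("transaction_status", "string"),
    ("channel", "string"),
    ("location", "string"),
    ("is_international", "boolean"),
]


def to_ddl(fields):
    """Return a DDL string such as 'a STRING, b INT'."""
    return ", ".join(f"{name} {TYPE_MAP[dtype]}" for name, dtype in fields)


transaction_ddl = to_ddl(TRANSACTION_FIELDS)
```

The resulting string could then be passed as `spark.read.schema(transaction_ddl).option("header", True).csv(path)`, which avoids the extra pass over the data that `inferSchema` needs.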

In [0]:
4. Credit Card Data
card_id: string
customer_id: string
account_id: string
card_type: string (visa, mastercard, amex)
card_status: string (active, blocked, expired)
issue_date: date
expiry_date: date
credit_limit: decimal
current_balance: decimal
available_credit: decimal
last_payment_date: date
last_payment_amount: decimal
interest_rate: float
reward_points: integer

In [0]:
5. Loan Data
loan_id: string
customer_id: string
loan_type: string (personal, mortgage, auto, business)
loan_status: string (active, paid, defaulted)
loan_amount: decimal
interest_rate: float
term_months: integer
start_date: date
end_date: date
monthly_payment: decimal
remaining_balance: decimal
next_payment_date: date
collateral_value: decimal
collateral_type: string

In [0]:
1. Spark Session Initialization
# src/utils/spark_session.py

from pyspark.sql import SparkSession
from pyspark.sql.types import *
import os

def create_spark_session(app_name="Banking ETL Pipeline"):
    """
    Create and configure a Spark session for the ETL pipeline.
    
    Args:
        app_name (str): Name of the Spark application
        
    Returns:
        SparkSession: Configured Spark session
    """
    # For Databricks, we can use the existing spark session
    if "DATABRICKS_RUNTIME_VERSION" in os.environ:
        return SparkSession.builder.getOrCreate()
    
    # For local or AWS EMR execution
    return (SparkSession.builder
            .appName(app_name)
            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
            .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
            .config("spark.databricks.delta.retentionDurationCheck.enabled", "false")
            .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
            .config("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "LEGACY")
            .config("spark.sql.warehouse.dir", "s3://your-data-lake-bucket/warehouse")
            .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
            .config("spark.hadoop.fs.s3a.aws.credentials.provider", 
                   "com.amazonaws.auth.DefaultAWSCredentialsProviderChain")
            .config("spark.hadoop.fs.s3a.connection.maximum", "100")
            .config("spark.sql.adaptive.enabled", "true")
            .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
            .config("spark.sql.shuffle.partitions", "200")
            .config("spark.default.parallelism", "200")
            .enableHiveSupport()
            .getOrCreate())

In [0]:
2. Data Ingestion from S3
# src/ingestion/s3_connector.py

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
import logging

logger = logging.getLogger(__name__)

class S3Connector:
    """Class to handle data ingestion from AWS S3."""
    
    def __init__(self, spark: SparkSession, bucket_name: str):
        """
        Initialize S3 connector.
        
        Args:
            spark (SparkSession): Spark session
            bucket_name (str): S3 bucket name
        """
        self.spark = spark
        self.bucket_name = bucket_name
        
    def read_csv(self, file_path: str, header: bool = True, infer_schema: bool = True) -> DataFrame:
        """
        Read CSV file from S3.
        
        Args:
            file_path (str): Path to the CSV file in S3
            header (bool): Whether the CSV has a header
            infer_schema (bool): Whether to infer the schema
            
        Returns:
            DataFrame: Spark DataFrame containing the data
        """
        try:
            full_path = f"s3a://{self.bucket_name}/{file_path}"
            logger.info(f"Reading CSV file from {full_path}")
            
            return (self.spark.read
                   .option("header", header)
                   .option("inferSchema", infer_schema)
                   .csv(full_path))
        except Exception as e:
            logger.error(f"Error reading CSV file from {full_path}: {str(e)}")
            raise
            
    def read_parquet(self, file_path: str) -> DataFrame:
        """
        Read Parquet file from S3.
        
        Args:
            file_path (str): Path to the Parquet file in S3
            
        Returns:
            DataFrame: Spark DataFrame containing the data
        """
        try:
            full_path = f"s3a://{self.bucket_name}/{file_path}"
            logger.info(f"Reading Parquet file from {full_path}")
            
            return self.spark.read.parquet(full_path)
        except Exception as e:
            logger.error(f"Error reading Parquet file from {full_path}: {str(e)}")
            raise
            
    def read_delta(self, file_path: str) -> DataFrame:
        """
        Read Delta table from S3.
        
        Args:
            file_path (str): Path to the Delta table in S3
            
        Returns:
            DataFrame: Spark DataFrame containing the data
        """
        try:
            full_path = f"s3a://{self.bucket_name}/{file_path}"
            logger.info(f"Reading Delta table from {full_path}")
            
            return self.spark.read.format("delta").load(full_path)
        except Exception as e:
            logger.error(f"Error reading Delta table from {full_path}: {str(e)}")
            raise

In [0]:
3. Data Transformation for Transactions
# src/transformation/transaction_transform.py

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import logging

logger = logging.getLogger(__name__)

class TransactionTransformer:
    """Class to handle transaction data transformations."""
    
    def __init__(self, spark: SparkSession):
        """
        Initialize transaction transformer.
        
        Args:
            spark (SparkSession): Spark session
        """
        self.spark = spark
        
    def clean_transaction_data(self, df: DataFrame) -> DataFrame:
        """
        Clean transaction data by handling missing values and data type conversions.
        
        Args:
            df (DataFrame): Raw transaction data
            
        Returns:
            DataFrame: Cleaned transaction data
        """
        logger.info("Cleaning transaction data")
        
        # Convert date strings to timestamp
        df = df.withColumn("transaction_date", 
                          to_timestamp(col("transaction_date"), "yyyy-MM-dd HH:mm:ss"))
        
        # Handle missing values
        df = df.na.fill("Unknown", ["merchant_name", "merchant_category", "description"])
        
        # Standardize transaction types first so the filter below also catches
        # variants such as "DEPOSIT" or "Deposit"
        df = df.withColumn("transaction_type", lower(trim(col("transaction_type"))))
        
        # Filter out invalid transactions (e.g., negative amounts for deposits)
        df = df.filter(~((col("transaction_type") == "deposit") & (col("amount") < 0)))
        
        return df
    
    def enrich_transaction_data(self, df: DataFrame) -> DataFrame:
        """
        Enrich transaction data with additional features.
        
        Args:
            df (DataFrame): Cleaned transaction data
            
        Returns:
            DataFrame: Enriched transaction data
        """
        logger.info("Enriching transaction data")
        
        # Extract date components
        df = df.withColumn("transaction_year", year(col("transaction_date")))
        df = df.withColumn("transaction_month", month(col("transaction_date")))
        df = df.withColumn("transaction_day", dayofmonth(col("transaction_date")))
        df = df.withColumn("transaction_hour", hour(col("transaction_date")))
        df = df.withColumn("transaction_dayofweek", dayofweek(col("transaction_date")))
        
        # Flag for weekend transactions
        df = df.withColumn("is_weekend", 
                          when(col("transaction_dayofweek").isin(1, 7), True)
                          .otherwise(False))
        
        # Calculate transaction amount in USD (assuming currency conversion)
        df = df.withColumn("amount_usd", 
                          when(col("currency") == "USD", col("amount"))
                          .when(col("currency") == "EUR", col("amount") * 1.1)
                          .when(col("currency") == "GBP", col("amount") * 1.3)
                          .otherwise(col("amount")))
        
        # Add transaction category based on merchant category
        df = df.withColumn("transaction_category",
                          when(col("merchant_category").isin("grocery", "supermarket"), "Groceries")
                          .when(col("merchant_category").isin("restaurant", "fast food"), "Dining")
                          .when(col("merchant_category").isin("gas", "fuel"), "Transportation")
                          .when(col("merchant_category").isin("utility", "electricity", "water"), "Utilities")
                          .otherwise("Other"))
        
        return df
    
    def calculate_transaction_metrics(self, df: DataFrame) -> DataFrame:
        """
        Calculate transaction metrics like running balances and spending patterns.
        
        Args:
            df (DataFrame): Enriched transaction data
            
        Returns:
            DataFrame: Transaction data with metrics
        """
        logger.info("Calculating transaction metrics")
        
        # Define window for running calculations by account
        window_spec = Window.partitionBy("account_id").orderBy("transaction_date")
        
        # Calculate running balance
        df = df.withColumn("amount_signed", 
                          when(col("transaction_type").isin("deposit", "transfer_in"), col("amount_usd"))
                          .otherwise(-col("amount_usd")))
        
        df = df.withColumn("running_balance", sum("amount_signed").over(window_spec))
        
        # Calculate days since last transaction
        df = df.withColumn("prev_transaction_date", 
                          lag("transaction_date", 1).over(window_spec))
        
        df = df.withColumn("days_since_last_transaction", 
                          when(col("prev_transaction_date").isNull(), 0)
                          .otherwise(datediff(col("transaction_date"), col("prev_transaction_date"))))
        
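        # Calculate transaction frequency metrics
        # A range frame with numeric bounds needs a numeric ordering column,
        # so order by the timestamp cast to epoch seconds
        window_30d = Window.partitionBy("account_id")\
                           .orderBy(col("transaction_date").cast("long"))\
                           .rangeBetween(-30 * 86400, 0)  # 30 days in seconds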
        
        df = df.withColumn("transaction_count_30d", count("transaction_id").over(window_30d))
        df = df.withColumn("total_spend_30d", 
                          sum(when(col("transaction_type").isin("withdrawal", "payment"), col("amount_usd"))
                             .otherwise(0)).over(window_30d))
        
        return df
    
    def detect_anomalies(self, df: DataFrame) -> DataFrame:
        """
        Detect anomalous transactions based on various rules.
        
        Args:
            df (DataFrame): Transaction data with metrics
            
        Returns:
            DataFrame: Transaction data with anomaly flags
        """
        logger.info("Detecting anomalous transactions")
        
        # Calculate account-level statistics
        account_stats = df.groupBy("account_id").agg(
            stddev("amount_usd").alias("amount_stddev"),
            avg("amount_usd").alias("amount_avg"),
            max("amount_usd").alias("amount_max")
        )
        
        # Join with transaction data
        df = df.join(account_stats, on="account_id", how="left")
        
        # Flag large transactions (> 3 standard deviations from mean)
        df = df.withColumn("is_large_transaction", 
                          (col("amount_usd") > (col("amount_avg") + 3 * col("amount_stddev"))) &
                          (col("amount_usd") > 1000))
        
        # Flag transactions in unusual locations
        df = df.withColumn("is_unusual_location", 
                          col("is_international") & 
                          ~col("location").isin("Canada", "Mexico", "United Kingdom", "France", "Germany"))
        
        # Flag high-frequency transactions
        df = df.withColumn("is_high_frequency", 
                          col("transaction_count_30d") > 100)
        
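        # Flag potential fraud based on combined factors.
        # datediff() returns whole days, so compare epoch seconds instead; the
        # first transaction of an account (no previous date) is not flagged.
        seconds_since_prev = (col("transaction_date").cast("long") -
                              col("prev_transaction_date").cast("long"))
        df = df.withColumn("potential_fraud", 
                          col("is_large_transaction") | 
                          col("is_unusual_location") |
                          (col("prev_transaction_date").isNotNull() &
                           (seconds_since_prev < 0.01 * 86400)))  # under 0.01 days (about 14 minutes)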
        
        return df

In [0]:
Explanation: The TransactionTransformer class handles the transformation of transaction data through several key methods:

clean_transaction_data(): Performs basic data cleaning operations:
- Converts string dates to proper timestamp format
- Fills missing values with defaults for text fields
- Removes invalid transactions (like negative deposit amounts)
- Standardizes transaction type values for consistency

enrich_transaction_data(): Adds valuable derived features:
- Extracts date components (year, month, day, hour, day of week)
- Adds a weekend flag for time-based analysis
- Converts transaction amounts to USD for consistent analysis
- Categorizes transactions based on merchant category

calculate_transaction_metrics(): Computes advanced metrics using window functions:
- Calculates running balance for each account
- Determines days between transactions
- Computes 30-day transaction counts and spending totals

detect_anomalies(): Identifies potentially suspicious transactions:
- Calculates statistical metrics for each account
- Flags unusually large transactions (statistical outliers)
- Identifies transactions in unusual locations
- Detects high-frequency transaction patterns
- Creates a composite fraud indicator based on multiple factors
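
The large-transaction rule in detect_anomalies() can be checked by hand in plain Python. The sketch below is an illustration, not part of the pipeline: `is_large_transaction` is a hypothetical helper that mirrors the rule (amount above the account average plus three standard deviations, and above 1000), using the sample standard deviation as Spark's `stddev` does. Because the account statistics include the transaction being tested, a single extreme value in a short history inflates the threshold and is not flagged:

```python
from statistics import mean, stdev


def is_large_transaction(amount, history, min_amount=1000.0):
    """Mirror the rule: amount > avg + 3 * stddev AND amount > min_amount.

    `history` holds every amount for the account, including `amount` itself,
    which matches how the account-level statistics are grouped.
    """
    avg = mean(history)
    sd = stdev(history)  # sample standard deviation
    return amount > avg + 3 * sd and amount > min_amount


# Illustrative amounts only: one 5000 payment among routine 50 payments
short_history = [50.0] * 7 + [5000.0]
long_history = [50.0] * 20 + [5000.0]

flag_short = is_large_transaction(5000.0, short_history)  # threshold exceeds 5000
flag_long = is_large_transaction(5000.0, long_history)    # threshold is below 5000
```

With only eight transactions the 5000 payment stays under its own threshold, while with twenty-one it is flagged, so accounts with short histories may need a separate rule.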

In [0]:
4. Data Quality Checks
# src/transformation/data_quality.py

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from typing import Dict, List, Tuple
import logging

logger = logging.getLogger(__name__)

class DataQualityChecker:
    """Class to perform data quality checks on datasets."""
    
    def __init__(self, spark: SparkSession):
        """
        Initialize data quality checker.
        
        Args:
            spark (SparkSession): Spark session
        """
        self.spark = spark
        
    def check_nulls(self, df: DataFrame, required_columns: List[str]) -> Tuple[bool, Dict[str, int]]:
        """
        Check for null values in required columns.
        
        Args:
            df (DataFrame): DataFrame to check
            required_columns (List[str]): List of columns that should not have nulls
            
        Returns:
            Tuple[bool, Dict[str, int]]: (passed/failed, dict of null counts by column)
        """
        logger.info(f"Checking for nulls in columns: {required_columns}")
        
        # Count nulls in each column
        null_counts = {}
        for column in required_columns:
            if column in df.columns:
                null_count = df.filter(col(column).isNull()).count()
                null_counts[column] = null_count
            else:
                logger.warning(f"Column {column} not found in DataFrame")
                null_counts[column] = "Column not found"
        
        # Check if any required column has nulls
        has_nulls = any(isinstance(count, int) and count > 0 for count in null_counts.values())
        
        if has_nulls:
            logger.warning(f"Null check failed. Null counts: {null_counts}")
            return False, null_counts
        else:
            logger.info("Null check passed")
            return True, null_counts
    
    def check_duplicates(self, df: DataFrame, key_columns: List[str]) -> Tuple[bool, int]:
        """
        Check for duplicate records based on key columns.
        
        Args:
            df (DataFrame): DataFrame to check
            key_columns (List[str]): Columns that should form a unique key
            
        Returns:
            Tuple[bool, int]: (passed/failed, count of duplicate records)
        """
        logger.info(f"Checking for duplicates on key columns: {key_columns}")
        
        # Count total rows
        total_rows = df.count()
        
        # Count distinct rows based on key columns
        distinct_rows = df.select(key_columns).distinct().count()
        
        # Calculate duplicates
        duplicate_count = total_rows - distinct_rows
        
        if duplicate_count > 0:
            logger.warning(f"Duplicate check failed. Found {duplicate_count} duplicates")
            return False, duplicate_count
        else:
            logger.info("Duplicate check passed")
            return True, 0
    
    def check_data_ranges(self, df: DataFrame, range_checks: Dict[str, Tuple]) -> Tuple[bool, Dict[str, int]]:
        """
        Check if values in columns fall within expected ranges.
        
        Args:
            df (DataFrame): DataFrame to check
            range_checks (Dict[str, Tuple]): Dictionary mapping column names to (min, max) tuples
            
        Returns:
            Tuple[bool, Dict[str, int]]: (passed/failed, dict of out-of-range counts by column)
        """
        logger.info(f"Checking data ranges for columns: {list(range_checks.keys())}")
        
        out_of_range_counts = {}
        
        for column, (min_val, max_val) in range_checks.items():
            if column in df.columns:
                # Count values outside the expected range
                out_of_range_count = df.filter(
                    (col(column) < min_val) | (col(column) > max_val)
                ).count()
                
                out_of_range_counts[column] = out_of_range_count
            else:
                logger.warning(f"Column {column} not found in DataFrame")
                out_of_range_counts[column] = "Column not found"
        
        # Check if any column has out-of-range values
        has_out_of_range = any(isinstance(count, int) and count > 0 for count in out_of_range_counts.values())
        
        if has_out_of_range:
            logger.warning(f"Range check failed. Out-of-range counts: {out_of_range_counts}")
            return False, out_of_range_counts
        else:
            logger.info("Range check passed")
            return True, out_of_range_counts
    
    def check_referential_integrity(self, df: DataFrame, ref_df: DataFrame, 
                                   fk_column: str, pk_column: str) -> Tuple[bool, int]:
        """
        Check referential integrity between two DataFrames.
        
        Args:
            df (DataFrame): DataFrame with foreign key
            ref_df (DataFrame): Reference DataFrame with primary key
            fk_column (str): Foreign key column in df
            pk_column (str): Primary key column in ref_df
            
        Returns:
            Tuple[bool, int]: (passed/failed, count of orphaned records)
        """
        logger.info(f"Checking referential integrity: {fk_column} -> {pk_column}")
        
        # Get distinct foreign keys
        fk_values = df.select(fk_column).distinct()
        
        # Get distinct primary keys
        pk_values = ref_df.select(pk_column).distinct()
        
        # Find orphaned records (foreign keys without matching primary keys)
        orphaned_records = fk_values.join(
            pk_values,
            fk_values[fk_column] == pk_values[pk_column],
            "left_anti"
        )
        
        orphaned_count = orphaned_records.count()
        
        if orphaned_count > 0:
            logger.warning(f"Referential integrity check failed. Found {orphaned_count} orphaned records")
            return False, orphaned_count
        else:
            logger.info("Referential integrity check passed")
            return True, 0
    
    def run_all_checks(self, df: DataFrame, check_config: Dict) -> Dict:
        """
        Run all configured data quality checks on a DataFrame.
        
        Args:
            df (DataFrame): DataFrame to check
            check_config (Dict): Configuration for checks to run
            
        Returns:
            Dict: Results of all checks
        """
        logger.info(f"Running all data quality checks for table: {check_config.get('table_name', 'unknown')}")
        
        results = {
            "table_name": check_config.get("table_name", "unknown"),
            "record_count": df.count(),
            "checks": {}
        }
        
        # Run null checks
        if "required_columns" in check_config:
            null_check_passed, null_counts = self.check_nulls(df, check_config["required_columns"])
            results["checks"]["null_check"] = {
                "passed": null_check_passed,
                "details": null_counts
            }
        
        # Run duplicate checks
        if "key_columns" in check_config:
            dup_check_passed, dup_count = self.check_duplicates(df, check_config["key_columns"])
            results["checks"]["duplicate_check"] = {
                "passed": dup_check_passed,
                "details": {"duplicate_count": dup_count}
            }
        
        # Run range checks
        if "range_checks" in check_config:
            range_check_passed, range_counts = self.check_data_ranges(df, check_config["range_checks"])
            results["checks"]["range_check"] = {
                "passed": range_check_passed,
                "details": range_counts
            }
        
        # Overall check result
        results["overall_passed"] = all(check["passed"] for check in results["checks"].values())
        
        return results

In [0]:
Explanation: The DataQualityChecker class implements comprehensive data quality validation, which is crucial for financial data processing. It provides several key methods:

check_nulls(): Validates that required columns don’t contain null values:
- Counts null values in each specified column
- Returns both a pass/fail status and detailed counts
- Handles missing columns gracefully

check_duplicates(): Ensures uniqueness of key columns:
- Compares total row count with distinct row count
- Identifies duplicate records based on specified key columns
- Returns both status and count of duplicates

check_data_ranges(): Verifies that values fall within expected ranges:
- Takes a dictionary mapping columns to min/max value ranges
- Counts out-of-range values for each column
- Returns detailed results by column

check_referential_integrity(): Validates foreign key relationships:
- Ensures all foreign keys have corresponding primary keys
- Uses anti-join to find orphaned records
- Returns count of integrity violations

run_all_checks(): Orchestrates multiple quality checks:
- Runs configured checks based on provided configuration
- Aggregates results into a structured report
- Determines overall pass/fail status

This class is essential for ensuring data reliability in a banking environment where data accuracy is critical for regulatory compliance and business operations.
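
As an illustration of the configuration that run_all_checks() reads, the sketch below builds a `check_config` dictionary with the keys the method looks up (`table_name`, `required_columns`, `key_columns`, `range_checks`). The column choices and range bounds are assumptions for the transaction table, and `failed_checks` is a hypothetical helper for reading the returned report:

```python
# Hypothetical configuration for the transaction table; column names follow
# the Transaction Data listing, the range bounds are illustrative assumptions.
transaction_checks = {
    "table_name": "transactions",
    "required_columns": ["transaction_id", "account_id", "transaction_date", "amount"],
    "key_columns": ["transaction_id"],
    "range_checks": {"amount": (0, 1_000_000)},
}


def failed_checks(results):
    """List the names of checks that did not pass in a run_all_checks()-style report."""
    return [name for name, check in results["checks"].items() if not check["passed"]]


# Hand-written report in the shape run_all_checks() returns (not real output)
sample_report = {
    "table_name": "transactions",
    "record_count": 3,
    "checks": {
        "null_check": {"passed": True, "details": {"transaction_id": 0}},
        "duplicate_check": {"passed": False, "details": {"duplicate_count": 1}},
    },
    "overall_passed": False,
}

failed = failed_checks(sample_report)
```

In the pipeline, the report would come from `DataQualityChecker(spark).run_all_checks(df, transaction_checks)`, and a non-empty `failed` list could be used to stop the load.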

In [0]:
5. Data Loading to Redshift
# src/loading/redshift_loader.py

from pyspark.sql import SparkSession, DataFrame
import logging
from typing import Dict, Optional

logger = logging.getLogger(__name__)

class RedshiftLoader:
    """Class to handle data loading to AWS Redshift."""
    
    def __init__(self, spark: SparkSession, jdbc_url: str, username: str, password: str):
        """
        Initialize Redshift loader.
        
        Args:
            spark (SparkSession): Spark session
            jdbc_url (str): JDBC URL for Redshift
            username (str): Redshift username
            password (str): Redshift password
        """
        self.spark = spark
        self.jdbc_url = jdbc_url
        self.username = username
        self.password = password
        
    def write_to_redshift(self, df: DataFrame, table_name: str, write_mode: str = "append",
                         preactions: Optional[str] = None, postactions: Optional[str] = None) -> None:
        """
        Write DataFrame to Redshift table.
        
        Args:
            df (DataFrame): DataFrame to write
            table_name (str): Target table name
            write_mode (str): Write mode (append, overwrite, error)
            preactions (Optional[str]): SQL to execute before writing
            postactions (Optional[str]): SQL to execute after writing
        """
        try:
            logger.info(f"Writing data to Redshift table: {table_name}")
            
            # Connection properties
            connection_properties = {
                "url": self.jdbc_url,
                "user": self.username,
                "password": self.password,
                "driver": "com.amazon.redshift.jdbc42.Driver",
                "dbtable": table_name
            }
            
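            # Add pre/post actions if provided.
            # Note: "preactions"/"postactions" are options of the spark-redshift
            # connector; Spark's built-in "jdbc" source does not document them, so
            # verify that they take effect with the data source you deploy.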
            if preactions:
                connection_properties["preactions"] = preactions
            if postactions:
                connection_properties["postactions"] = postactions
                
            # Write to Redshift
            df.write.format("jdbc") \
                .mode(write_mode) \
                .options(**connection_properties) \
                .save()
                
            logger.info(f"Successfully wrote data to Redshift table: {table_name}")
        except Exception as e:
            logger.error(f"Error writing to Redshift table {table_name}: {str(e)}")
            raise
            
    def load_with_staging(self, df: DataFrame, target_table: str, staging_table: str = None,
                         key_columns: list = None) -> None:
        """
        Load data to Redshift using a staging table for better performance.
        
        Args:
            df (DataFrame): DataFrame to write
            target_table (str): Target table name
            staging_table (str): Staging table name (defaults to target_table + '_staging')
            key_columns (list): Primary key columns for merging
        """
        try:
            if not staging_table:
                staging_table = f"{target_table}_staging"
                
            logger.info(f"Loading data to Redshift using staging table: {staging_table}")
            
            # Create staging table with the same structure as target table
            create_staging_sql = f"""
            DROP TABLE IF EXISTS {staging_table};
            CREATE TABLE {staging_table} (LIKE {target_table});
            """
            
            # Write data to staging table
            self.write_to_redshift(df, staging_table, "overwrite", preactions=create_staging_sql)
            
            # Merge data from staging to target table
            if key_columns and len(key_columns) > 0:
                # For upsert operation
                key_conditions = " AND ".join([f"target.{col} = source.{col}" for col in key_columns])
                non_key_columns = [col for col in df.columns if col not in key_columns]
                # The left side of SET must be an unqualified column name
                update_statements = ", ".join([f"{col} = source.{col}" for col in non_key_columns])
                insert_columns = ", ".join(df.columns)
                insert_values = ", ".join([f"source.{col}" for col in df.columns])
                
                merge_sql = f"""
                BEGIN TRANSACTION;
                
                -- Update existing records
                UPDATE {target_table} AS target
                SET {update_statements}
                FROM {staging_table} AS source
                WHERE {key_conditions};
                
                -- Insert new records
                INSERT INTO {target_table} ({insert_columns})
                SELECT {insert_values}
                FROM {staging_table} AS source
                LEFT JOIN {target_table} AS target
                ON {key_conditions}
                WHERE target.{key_columns[0]} IS NULL;
                
                -- Clean up staging table
                DROP TABLE IF EXISTS {staging_table};
                
                END TRANSACTION;
                """
                
                # Execute the merge SQL
                self.execute_sql(merge_sql)
            else:
                # For full load/truncate and load
                truncate_and_load_sql = f"""
                BEGIN TRANSACTION;
                
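                -- DELETE instead of TRUNCATE: in Redshift, TRUNCATE commits the
                -- current transaction, which would break the atomic replacement
                DELETE FROM {target_table};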
                
                INSERT INTO {target_table}
                SELECT * FROM {staging_table};
                
                DROP TABLE IF EXISTS {staging_table};
                
                END TRANSACTION;
                """
                
                # Execute the truncate and load SQL
                self.execute_sql(truncate_and_load_sql)
                
            logger.info(f"Successfully loaded data to Redshift table: {target_table}")
        except Exception as e:
            logger.error(f"Error loading data to Redshift table {target_table}: {str(e)}")
            raise
            
    def execute_sql(self, sql: str) -> None:
        """
        Execute SQL statement in Redshift.
        
        Args:
            sql (str): SQL statement to execute
        """
        try:
            logger.info("Executing SQL in Redshift")
            
            # Create a temporary DataFrame to execute SQL
            temp_df = self.spark.createDataFrame([("dummy",)], ["dummy"])
            
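            # Execute SQL as a postaction. This relies on the data source honoring
            # "postactions" (a spark-redshift connector option); with the plain
            # "jdbc" source, run the SQL through a direct database connection instead.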
            connection_properties = {
                "url": self.jdbc_url,
                "user": self.username,
                "password": self.password,
                "driver": "com.amazon.redshift.jdbc42.Driver",
                "dbtable": "(SELECT 1) AS dummy",
                "postactions": sql
            }
            
            temp_df.write.format("jdbc") \
                .mode("append") \
                .options(**connection_properties) \
                .save()
                
            logger.info("Successfully executed SQL in Redshift")
        except Exception as e:
            logger.error(f"Error executing SQL in Redshift: {str(e)}")
            raise

In [0]:
Explanation: The RedshiftLoader class handles loading processed data into Amazon Redshift, the AWS managed data warehouse. It provides several key methods:

write_to_redshift(): Basic method to write data directly to a Redshift table:
- Sets up JDBC connection properties
- Supports different write modes (append, overwrite)
- Allows pre and post SQL actions
- Includes comprehensive error handling
load_with_staging(): Implements a more sophisticated loading pattern using staging tables:
- Creates a temporary staging table
- Loads data to the staging table
- Performs either an upsert operation (for incremental loads) or a full replacement
- Uses transactions to ensure atomicity
- Cleans up staging tables after successful load
execute_sql(): Utility method to execute arbitrary SQL in Redshift:
- Creates a dummy DataFrame as a vehicle for SQL execution
- Passes the SQL to Redshift through the `postactions` option of a Spark JDBC write
- Provides error handling and logging
This class implements best practices for loading data to Redshift, including:

- Using staging tables to minimize impact on production tables
- Supporting both full and incremental loading patterns
- Using transactions to ensure data consistency
- Providing detailed logging for troubleshooting

These patterns are particularly important in banking applications where data integrity is critical.
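
The incremental (upsert) branch of load_with_staging() is only partly visible in the code above, so the following is a minimal sketch of one common Redshift pattern rather than the exact SQL the class generates: delete the target rows that match the staging rows on the key columns, then insert everything from staging, all inside one transaction. The function name `build_upsert_sql` and the staging table name are hypothetical.

```python
from typing import List


def build_upsert_sql(target_table: str, staging_table: str, key_columns: List[str]) -> str:
    """Build a delete-then-insert upsert wrapped in a single transaction.

    Hypothetical helper for illustration; not part of RedshiftLoader.
    """
    if not key_columns:
        raise ValueError("key_columns must not be empty for an upsert")

    # Match target rows to staging rows on every key column
    join_condition = " AND ".join(
        f"{target_table}.{col} = {staging_table}.{col}" for col in key_columns
    )

    statements = [
        "BEGIN TRANSACTION;",
        # Remove the target rows that are about to be replaced
        f"DELETE FROM {target_table} USING {staging_table} WHERE {join_condition};",
        # Insert the fresh versions together with any brand-new rows
        f"INSERT INTO {target_table} SELECT * FROM {staging_table};",
        f"DROP TABLE IF EXISTS {staging_table};",
        "END TRANSACTION;",
    ]
    return "\n".join(statements)


sql = build_upsert_sql("dim_customer", "dim_customer_staging", ["customer_id"])
print(sql)
```

Building the statement in a small pure function keeps it easy to unit test without a Redshift connection. Table and column names are interpolated directly, so they should come from trusted configuration only.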

In [0]:
6. Main Pipeline Orchestration
# src/orchestration/main_pipeline.py

from pyspark.sql import SparkSession
import logging
import json
import os
from datetime import datetime

# Import project modules
from src.utils.spark_session import create_spark_session
from src.ingestion.s3_connector import S3Connector
from src.ingestion.rds_connector import RDSConnector
from src.transformation.customer_transform import CustomerTransformer
from src.transformation.transaction_transform import TransactionTransformer
from src.transformation.account_transform import AccountTransformer
from src.transformation.data_quality import DataQualityChecker
from src.loading.redshift_loader import RedshiftLoader
from src.loading.s3_loader import S3Loader
from src.utils.logging_utils import setup_logging

logger = logging.getLogger(__name__)

class BankingETLPipeline:
    """Main class to orchestrate the banking ETL pipeline."""
    
    def __init__(self, config_path: str):
        """
        Initialize the ETL pipeline.
        
        Args:
            config_path (str): Path to the configuration file
        """
        # Set up logging
        setup_logging()
        
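        # Load configuration, expanding ${VAR} placeholders from the environment
        # (os.path.expandvars leaves unset variables unchanged)
        logger.info(f"Loading configuration from {config_path}")
        with open(config_path, 'r') as config_file:
            self.config = json.loads(os.path.expandvars(config_file.read()))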
        
        # Initialize Spark session
        logger.info("Initializing Spark session")
        self.spark = create_spark_session(app_name=self.config.get("app_name", "Banking ETL Pipeline"))
        
        # Initialize components
        self._init_components()
        
        # Set execution date
        self.execution_date = datetime.now().strftime("%Y-%m-%d")
        
    def _init_components(self):
        """Initialize pipeline components based on configuration."""
        logger.info("Initializing pipeline components")
        
        # Initialize data connectors
        s3_config = self.config.get("s3", {})
        self.s3_connector = S3Connector(
            self.spark, 
            s3_config.get("bucket_name", "banking-data-lake")
        )
        
        rds_config = self.config.get("rds", {})
        self.rds_connector = RDSConnector(
            self.spark,
            rds_config.get("jdbc_url"),
            rds_config.get("username"),
            rds_config.get("password")
        )
        
        # Initialize transformers
        self.customer_transformer = CustomerTransformer(self.spark)
        self.transaction_transformer = TransactionTransformer(self.spark)
        self.account_transformer = AccountTransformer(self.spark)
        
        # Initialize data quality checker
        self.data_quality_checker = DataQualityChecker(self.spark)
        
        # Initialize data loaders
        redshift_config = self.config.get("redshift", {})
        self.redshift_loader = RedshiftLoader(
            self.spark,
            redshift_config.get("jdbc_url"),
            redshift_config.get("username"),
            redshift_config.get("password")
        )
        
        self.s3_loader = S3Loader(
            self.spark,
            s3_config.get("bucket_name", "banking-data-lake")
        )
        
    def run_customer_pipeline(self):
        """Run the customer data pipeline."""
        logger.info("Running customer data pipeline")
        
        try:
            # Extract customer data
            customer_config = self.config.get("pipelines", {}).get("customer", {})
            source_type = customer_config.get("source_type")
            
            if source_type == "s3":
                raw_customers = self.s3_connector.read_csv(
                    customer_config.get("source_path")
                )
            elif source_type == "rds":
                raw_customers = self.rds_connector.read_table(
                    customer_config.get("source_table")
                )
            else:
                raise ValueError(f"Unsupported source type: {source_type}")
            
            # Transform customer data
            cleaned_customers = self.customer_transformer.clean_customer_data(raw_customers)
            enriched_customers = self.customer_transformer.enrich_customer_data(cleaned_customers)
            
            # Run data quality checks
            quality_results = self.data_quality_checker.run_all_checks(
                enriched_customers,
                customer_config.get("data_quality", {})
            )
            
            if not quality_results.get("overall_passed", False):
                logger.warning("Data quality checks failed for customer data")
                # Depending on configuration, we might still proceed
                if customer_config.get("fail_on_quality_check", True):
                    raise Exception("Data quality checks failed for customer data")
            
            # Load customer data
            target_type = customer_config.get("target_type")
            
            if target_type == "redshift":
                self.redshift_loader.load_with_staging(
                    enriched_customers,
                    customer_config.get("target_table"),
                    key_columns=customer_config.get("key_columns", ["customer_id"])
                )
            elif target_type == "s3":
                self.s3_loader.write_delta(
                    enriched_customers,
                    customer_config.get("target_path"),
                    mode=customer_config.get("write_mode", "overwrite"),
                    partition_cols=customer_config.get("partition_cols", [])
                )
            else:
                raise ValueError(f"Unsupported target type: {target_type}")
            
            logger.info("Customer data pipeline completed successfully")
            return True
        except Exception as e:
            logger.error(f"Error in customer data pipeline: {str(e)}")
            raise
    
    def run_transaction_pipeline(self):
        """Run the transaction data pipeline."""
        logger.info("Running transaction data pipeline")
        
        try:
            # Extract transaction data
            transaction_config = self.config.get("pipelines", {}).get("transaction", {})
            source_type = transaction_config.get("source_type")
            
            if source_type == "s3":
                raw_transactions = self.s3_connector.read_csv(
                    transaction_config.get("source_path")
                )
            elif source_type == "rds":
                raw_transactions = self.rds_connector.read_table(
                    transaction_config.get("source_table")
                )
            else:
                raise ValueError(f"Unsupported source type: {source_type}")
            
            # Transform transaction data
            cleaned_transactions = self.transaction_transformer.clean_transaction_data(raw_transactions)
            enriched_transactions = self.transaction_transformer.enrich_transaction_data(cleaned_transactions)
            transactions_with_metrics = self.transaction_transformer.calculate_transaction_metrics(enriched_transactions)
            final_transactions = self.transaction_transformer.detect_anomalies(transactions_with_metrics)
            
            # Run data quality checks
            quality_results = self.data_quality_checker.run_all_checks(
                final_transactions,
                transaction_config.get("data_quality", {})
            )
            
            if not quality_results.get("overall_passed", False):
                logger.warning("Data quality checks failed for transaction data")
                # Depending on configuration, we might still proceed
                if transaction_config.get("fail_on_quality_check", True):
                    raise Exception("Data quality checks failed for transaction data")
            
            # Load transaction data
            target_type = transaction_config.get("target_type")
            
            if target_type == "redshift":
                self.redshift_loader.load_with_staging(
                    final_transactions,
                    transaction_config.get("target_table"),
                    key_columns=transaction_config.get("key_columns", ["transaction_id"])
                )
            elif target_type == "s3":
                self.s3_loader.write_delta(
                    final_transactions,
                    transaction_config.get("target_path"),
                    mode=transaction_config.get("write_mode", "append"),
                    partition_cols=transaction_config.get("partition_cols", ["transaction_year", "transaction_month"])
                )
            else:
                raise ValueError(f"Unsupported target type: {target_type}")
            
            logger.info("Transaction data pipeline completed successfully")
            return True
        except Exception as e:
            logger.error(f"Error in transaction data pipeline: {str(e)}")
            raise
    
    def run_account_pipeline(self):
        """Run the account data pipeline."""
        logger.info("Running account data pipeline")
        
        try:
            # Extract account data
            account_config = self.config.get("pipelines", {}).get("account", {})
            source_type = account_config.get("source_type")
            
            if source_type == "s3":
                raw_accounts = self.s3_connector.read_csv(
                    account_config.get("source_path")
                )
            elif source_type == "rds":
                raw_accounts = self.rds_connector.read_table(
                    account_config.get("source_table")
                )
            else:
                raise ValueError(f"Unsupported source type: {source_type}")
            
            # Transform account data
            cleaned_accounts = self.account_transformer.clean_account_data(raw_accounts)
            enriched_accounts = self.account_transformer.enrich_account_data(cleaned_accounts)
            
            # Run data quality checks
            quality_results = self.data_quality_checker.run_all_checks(
                enriched_accounts,
                account_config.get("data_quality", {})
            )
            
            if not quality_results.get("overall_passed", False):
                logger.warning("Data quality checks failed for account data")
                # Depending on configuration, we might still proceed
                if account_config.get("fail_on_quality_check", True):
                    raise Exception("Data quality checks failed for account data")
            
            # Load account data
            target_type = account_config.get("target_type")
            
            if target_type == "redshift":
                self.redshift_loader.load_with_staging(
                    enriched_accounts,
                    account_config.get("target_table"),
                    key_columns=account_config.get("key_columns", ["account_id"])
                )
            elif target_type == "s3":
                self.s3_loader.write_delta(
                    enriched_accounts,
                    account_config.get("target_path"),
                    mode=account_config.get("write_mode", "overwrite"),
                    partition_cols=account_config.get("partition_cols", [])
                )
            else:
                raise ValueError(f"Unsupported target type: {target_type}")
            
            logger.info("Account data pipeline completed successfully")
            return True
        except Exception as e:
            logger.error(f"Error in account data pipeline: {str(e)}")
            raise
    
    def run_pipeline(self):
        """Run the complete ETL pipeline."""
        logger.info("Starting the banking ETL pipeline")
        
        try:
            # Run individual pipelines based on configuration
            pipelines_to_run = self.config.get("pipelines_to_run", [])
            
            if "customer" in pipelines_to_run:
                self.run_customer_pipeline()
            
            if "account" in pipelines_to_run:
                self.run_account_pipeline()
            
            if "transaction" in pipelines_to_run:
                self.run_transaction_pipeline()
            
            logger.info("Banking ETL pipeline completed successfully")
            return True
        except Exception as e:
            logger.error(f"Error in banking ETL pipeline: {str(e)}")
            raise
        finally:
            # Clean up resources
            logger.info("Cleaning up resources")
            self.spark.stop()

if __name__ == "__main__":
    # Get configuration path from environment variable or use default
    config_path = os.environ.get("ETL_CONFIG_PATH", "config/config.json")
    
    # Run the pipeline
    pipeline = BankingETLPipeline(config_path)
    pipeline.run_pipeline()

In [0]:
Explanation: The BankingETLPipeline class serves as the main orchestrator for the entire ETL process, tying together all the components we’ve built. It provides:

Initialization and Configuration:
- Loads configuration from a JSON file
- Sets up logging
- Initializes a Spark session
- Creates instances of all required components (connectors, transformers, loaders)
Pipeline Execution Methods:
- run_customer_pipeline(): Processes customer data through extraction, transformation, quality checks, and loading
- run_transaction_pipeline(): Handles transaction data with specialized transformations and anomaly detection
- run_account_pipeline(): Processes account data with appropriate transformations
- run_pipeline(): Orchestrates the execution of all sub-pipelines based on configuration
Error Handling and Resource Management:
- Wraps each pipeline method in try/except, logging the error before re-raising it
- Stops the Spark session in the finally block of run_pipeline(), whether or not the run succeeded
- Logs each stage (start, quality-check failures, completion) for troubleshooting
The class follows a modular design pattern, making it easy to:

- Add new data pipelines
- Modify existing pipelines
- Configure which pipelines to run
- Handle different source and target types

This orchestration layer is crucial for managing the complexity of a multi-domain ETL pipeline in a banking environment.
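
One way to make adding a pipeline a one-line change is to replace the chain of `if "..." in pipelines_to_run` checks with a registry. The sketch below is an illustration under assumptions, not the class's actual code: `run_selected` is a hypothetical helper, and the lambdas stand in for the run_*_pipeline methods. It keeps the fixed customer, account, transaction order used by run_pipeline() and additionally rejects unknown names, which the original silently ignores.

```python
from typing import Callable, Dict, List


def run_selected(pipelines: Dict[str, Callable[[], bool]], to_run: List[str]) -> List[str]:
    """Run the requested pipelines in registration order and return their names."""
    unknown = [name for name in to_run if name not in pipelines]
    if unknown:
        raise ValueError(f"Unknown pipelines requested: {unknown}")

    executed = []
    # Dicts preserve insertion order, so registration order is execution order
    for name, runner in pipelines.items():
        if name in to_run:
            runner()
            executed.append(name)
    return executed


# Stand-in runners; in the real class these would be the bound run_*_pipeline methods
registry = {
    "customer": lambda: True,
    "account": lambda: True,
    "transaction": lambda: True,
}
executed = run_selected(registry, ["transaction", "customer"])
print(executed)  # ['customer', 'transaction']
```

Running in registration order rather than request order matters here because dimension tables (customers, accounts) are normally loaded before the transactions that reference them.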

In [0]:
Sample Configuration File
{
  "app_name": "Banking ETL Pipeline",
  "environment": "production",
  "pipelines_to_run": ["customer", "account", "transaction"],
  
  "s3": {
    "bucket_name": "banking-data-lake",
    "region": "us-east-1"
  },
  
  "rds": {
    "jdbc_url": "jdbc:mysql://banking-db.cluster-xyz.us-east-1.rds.amazonaws.com:3306/banking",
    "username": "${RDS_USERNAME}",
    "password": "${RDS_PASSWORD}"
  },
  
  "redshift": {
    "jdbc_url": "jdbc:redshift://banking-warehouse.xyz.us-east-1.redshift.amazonaws.com:5439/banking",
    "username": "${REDSHIFT_USERNAME}",
    "password": "${REDSHIFT_PASSWORD}"
  },
  
  "pipelines": {
    "customer": {
      "source_type": "s3",
      "source_path": "raw/customers/",
      "target_type": "redshift",
      "target_table": "dim_customer",
      "key_columns": ["customer_id"],
      "fail_on_quality_check": true,
      "data_quality": {
        "table_name": "dim_customer",
        "required_columns": ["customer_id", "first_name", "last_name", "email"],
        "key_columns": ["customer_id"],
        "range_checks": {
          "credit_score": [300, 850]
        }
      }
    },
    
    "account": {
      "source_type": "rds",
      "source_table": "accounts",
      "target_type": "redshift",
      "target_table": "dim_account",
      "key_columns": ["account_id"],
      "fail_on_quality_check": true,
      "data_quality": {
        "table_name": "dim_account",
        "required_columns": ["account_id", "customer_id", "account_type", "open_date"],
        "key_columns": ["account_id"],
        "range_checks": {
          "balance": [0, 10000000],
          "interest_rate": [0, 30]
        }
      }
    },
    
    "transaction": {
      "source_type": "s3",
      "source_path": "raw/transactions/",
      "target_type": "s3",
      "target_path": "processed/transactions/",
      "write_mode": "append",
      "partition_cols": ["transaction_year", "transaction_month"],
      "fail_on_quality_check": false,
      "data_quality": {
        "table_name": "fact_transaction",
        "required_columns": ["transaction_id", "account_id", "transaction_date", "amount"],
        "key_columns": ["transaction_id"],
        "range_checks": {
          "amount": [0, 1000000]
        }
      }
    }
  }
}

In [0]:
Explanation: This configuration file provides a centralized way to control the ETL pipeline’s behavior without changing code. Key sections include:

General Settings:
- Application name and environment
- List of pipelines to run
Connection Information:
- S3 bucket details for data lake storage
- RDS connection parameters for relational database access
- Redshift connection parameters for data warehouse access
- Note the use of environment variable placeholders for sensitive credentials
Pipeline-Specific Configurations:
- Customer Pipeline: Reads from S3, loads to Redshift, with specific data quality checks
- Account Pipeline: Reads from RDS, loads to Redshift, with account-specific validations
- Transaction Pipeline: Reads from S3, writes to S3 Delta Lake with time-based partitioning
Each pipeline configuration includes:

- Source and target specifications
- Data quality requirements
- Failure handling policies
- Key columns for deduplication and merging

This configuration-driven approach allows for flexible deployment across environments and easy modification of pipeline behavior without code changes.
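
Because so much behavior lives in the configuration, it can help to validate it before any Spark work starts. The following is a rough sketch, assuming the key names used in the sample file (`source_path`, `source_table`, `target_path`, `target_table`); the function `validate_pipeline_config` and the two lookup tables are hypothetical and not part of the pipeline code.

```python
from typing import Dict, List

# Keys each source/target type needs, taken from the sample configuration
REQUIRED_BY_SOURCE = {"s3": ["source_path"], "rds": ["source_table"]}
REQUIRED_BY_TARGET = {"s3": ["target_path"], "redshift": ["target_table"]}


def validate_pipeline_config(name: str, cfg: Dict) -> List[str]:
    """Return a list of human-readable problems; an empty list means the config looks usable."""
    errors = []
    for kind, required in (("source_type", REQUIRED_BY_SOURCE),
                           ("target_type", REQUIRED_BY_TARGET)):
        value = cfg.get(kind)
        if value not in required:
            errors.append(f"{name}: unsupported {kind} {value!r}")
            continue
        for key in required[value]:
            if not cfg.get(key):
                errors.append(f"{name}: {kind}={value!r} requires '{key}'")
    return errors


pipelines = {
    "customer": {"source_type": "s3", "source_path": "raw/customers/",
                 "target_type": "redshift", "target_table": "dim_customer"},
    # source_table is deliberately missing to show a failure
    "account": {"source_type": "rds",
                "target_type": "redshift", "target_table": "dim_account"},
}
problems = [msg for name, cfg in pipelines.items()
            for msg in validate_pipeline_config(name, cfg)]
print(problems)  # ["account: source_type='rds' requires 'source_table'"]
```

Collecting every problem instead of raising on the first one lets an operator fix the whole file in a single pass.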

In [0]:
# 2. Generating Sample Data
# Create a notebook to generate sample data:

# Databricks notebook: Generate Sample Data

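from pyspark.sql import SparkSession
# Import names explicitly: a wildcard import of pyspark.sql.functions would shadow
# Python built-ins such as round(), which generate_accounts() relies on.
from pyspark.sql.functions import col
from pyspark.sql.types import (
    StructType, StructField, StringType, DateType, IntegerType,
    FloatType, DecimalType, TimestampType, BooleanType
)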
import random
from datetime import datetime, timedelta
import uuid

# Initialize Spark session
spark = SparkSession.builder.getOrCreate()

# Define schemas
customer_schema = StructType([
    StructField("customer_id", StringType(), False),
    StructField("first_name", StringType(), True),
    StructField("last_name", StringType(), True),
    StructField("date_of_birth", DateType(), True),
    StructField("email", StringType(), True),
    StructField("phone_number", StringType(), True),
    StructField("address", StringType(), True),
    StructField("city", StringType(), True),
    StructField("state", StringType(), True),
    StructField("zip_code", StringType(), True),
    StructField("country", StringType(), True),
    StructField("customer_since", DateType(), True),
    StructField("credit_score", IntegerType(), True),
    StructField("risk_segment", StringType(), True)
])

account_schema = StructType([
    StructField("account_id", StringType(), False),
    StructField("customer_id", StringType(), False),
    StructField("account_type", StringType(), True),
    StructField("account_status", StringType(), True),
    StructField("open_date", DateType(), True),
    StructField("close_date", DateType(), True),
    StructField("currency", StringType(), True),
    StructField("branch_id", StringType(), True),
    StructField("interest_rate", FloatType(), True),
    StructField("balance", DecimalType(18, 2), True),
    StructField("last_activity_date", DateType(), True)
])

transaction_schema = StructType([
    StructField("transaction_id", StringType(), False),
    StructField("account_id", StringType(), False),
    StructField("transaction_date", TimestampType(), False),
    StructField("transaction_type", StringType(), True),
    StructField("amount", DecimalType(18, 2), True),
    StructField("currency", StringType(), True),
    StructField("description", StringType(), True),
    StructField("merchant_name", StringType(), True),
    StructField("merchant_category", StringType(), True),
    StructField("transaction_status", StringType(), True),
    StructField("channel", StringType(), True),
    StructField("location", StringType(), True),
    StructField("is_international", BooleanType(), True)
])

# Generate customer data
def generate_customers(num_customers=1000):
    first_names = ["James", "Mary", "John", "Patricia", "Robert", "Jennifer", "Michael", "Linda", "William", "Elizabeth"]
    last_names = ["Smith", "Johnson", "Williams", "Jones", "Brown", "Davis", "Miller", "Wilson", "Moore", "Taylor"]
    states = ["CA", "NY", "TX", "FL", "IL", "PA", "OH", "GA", "NC", "MI"]
    cities = ["Los Angeles", "New York", "Houston", "Miami", "Chicago", "Philadelphia", "Columbus", "Atlanta", "Charlotte", "Detroit"]
    risk_segments = ["Low", "Medium", "High"]
    
    customers = []
    
    for i in range(num_customers):
        customer_id = f"CUST{i:06d}"
        first_name = random.choice(first_names)
        last_name = random.choice(last_names)
        
        # Generate date of birth (21-80 years old)
        years_ago = random.randint(21, 80)
        dob = datetime.now() - timedelta(days=365 * years_ago)
        
        # Generate customer since date (0-10 years ago)
        years_customer = random.randint(0, 10)
        customer_since = datetime.now() - timedelta(days=365 * years_customer)
        
        # Pick a matching state/city pair (the two lists are index-aligned)
        state, city = random.choice(list(zip(states, cities)))
        
        customers.append((
            customer_id,
            first_name,
            last_name,
            dob.date(),
            f"{first_name.lower()}.{last_name.lower()}@example.com",
            f"555-{random.randint(100, 999)}-{random.randint(1000, 9999)}",
            f"{random.randint(100, 9999)} Main St",
            city,
            state,
            f"{random.randint(10000, 99999)}",
            "USA",
            customer_since.date(),
            random.randint(300, 850),
            random.choice(risk_segments)
        ))
    
    return spark.createDataFrame(customers, customer_schema)

import decimal

def generate_accounts(customers_df, num_accounts=1500):
    account_types = ["checking", "savings", "investment"]
    account_statuses = ["active", "closed", "suspended"]
    currencies = ["USD", "EUR", "GBP"]
    
    accounts = []
    customer_ids = [row.customer_id for row in customers_df.select("customer_id").collect()]
    
    for i in range(num_accounts):
        account_id = f"ACC{i:08d}"
        customer_id = random.choice(customer_ids)
        account_type = random.choice(account_types)
        account_status = random.choice(account_statuses)
        years_ago = random.randint(0, 5)
        open_date = datetime.now() - timedelta(days=365 * years_ago)
        close_date = None
        if account_status == "closed":
            days_ago = random.randint(0, 365)
            close_date = datetime.now() - timedelta(days=days_ago)
        days_ago = random.randint(0, 30)
        last_activity_date = datetime.now() - timedelta(days=days_ago)
        
        # Fix: Convert balance to decimal.Decimal
        balance = decimal.Decimal(str(round(random.uniform(0, 100000), 2)))
        
        accounts.append((
            account_id,
            customer_id,
            account_type,
            account_status,
            open_date.date(),
            close_date,
            random.choice(currencies),
            f"BR{random.randint(100, 999)}",
            random.uniform(0.01, 5.0),
            balance,
            last_activity_date.date()
        ))
    
    return spark.createDataFrame(accounts, account_schema)

# Generate transaction data
def generate_transactions(accounts_df, num_transactions=10000):
    transaction_types = ["deposit", "withdrawal", "transfer", "payment"]
    currencies = ["USD", "EUR", "GBP"]
    merchant_categories = ["grocery", "restaurant", "retail", "travel", "utility", "entertainment"]
    transaction_statuses = ["completed", "pending", "failed", "reversed"]
    channels = ["online", "mobile", "branch", "atm"]
    locations = ["USA", "Canada", "UK", "France", "Germany", "Japan", "Australia", "Brazil", "Mexico", "China"]
    
    transactions = []
    
    # Get active account IDs
    account_ids = [row.account_id for row in accounts_df.filter(col("account_status") == "active").select("account_id").collect()]
    
    for i in range(num_transactions):
        transaction_id = str(uuid.uuid4())
        account_id = random.choice(account_ids)
        
        # Generate transaction date (last 90 days)
        days_ago = random.randint(0, 90)
        hours_ago = random.randint(0, 23)
        minutes_ago = random.randint(0, 59)
        transaction_date = datetime.now() - timedelta(days=days_ago, hours=hours_ago, minutes=minutes_ago)
        
        transaction_type = random.choice(transaction_types)
        
        # Amount based on transaction type
        if transaction_type == "deposit":
            amount = random.uniform(10, 5000)
        elif transaction_type == "withdrawal":
            amount = random.uniform(10, 1000)
        elif transaction_type == "transfer":
            amount = random.uniform(10, 3000)
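        else:  # payment
            amount = random.uniform(10, 2000)
        
        # Convert to Decimal so the value matches DecimalType(18, 2) in the schema;
        # Spark rejects a plain float for a decimal column
        amount = decimal.Decimal(f"{amount:.2f}")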
        
        currency = random.choice(currencies)
        merchant_category = random.choice(merchant_categories)
        
        # Generate merchant name based on category
        if merchant_category == "grocery":
            merchant_name = random.choice(["Whole Foods", "Safeway", "Kroger", "Trader Joe's"])
        elif merchant_category == "restaurant":
            merchant_name = random.choice(["McDonald's", "Starbucks", "Chipotle", "Olive Garden"])
        elif merchant_category == "retail":
            merchant_name = random.choice(["Amazon", "Walmart", "Target", "Best Buy"])
        elif merchant_category == "travel":
            merchant_name = random.choice(["Delta Airlines", "Marriott", "Expedia", "Uber"])
        elif merchant_category == "utility":
            merchant_name = random.choice(["AT&T", "PG&E", "Comcast", "Verizon"])
        else:  # entertainment
            merchant_name = random.choice(["Netflix", "AMC Theaters", "Spotify", "Disney+"])
        
        location = random.choice(locations)
        is_international = location != "USA"
        
        transactions.append((
            transaction_id,
            account_id,
            transaction_date,
            transaction_type,
            amount,
            currency,
            f"{transaction_type.capitalize()} at {merchant_name}",
            merchant_name,
            merchant_category,
            random.choice(transaction_statuses),
            random.choice(channels),
            location,
            is_international
        ))
    
    return spark.createDataFrame(transactions, transaction_schema)

# Generate the data
customers_df = generate_customers(1000)
accounts_df = generate_accounts(customers_df, 1500)
transactions_df = generate_transactions(accounts_df, 10000)

# Write data to S3
customers_df.write.mode("overwrite").csv("s3://banking-data-lake/raw/customers/", header=True)
accounts_df.write.mode("overwrite").csv("s3://banking-data-lake/raw/accounts/", header=True)
transactions_df.write.mode("overwrite").csv("s3://banking-data-lake/raw/transactions/", header=True)

print("Sample data generation complete!")

In [0]:
# Databricks notebook: Run ETL Pipeline

# Import required modules
import sys
import os

# Add project directory to Python path
project_path = "/dbfs/FileStore/banking-etl-pipeline"
sys.path.append(project_path)

# Import the main pipeline class
from src.orchestration.main_pipeline import BankingETLPipeline

# Set configuration path
config_path = os.path.join(project_path, "config/config.json")

# Initialize and run the pipeline
pipeline = BankingETLPipeline(config_path)
pipeline.run_pipeline()

In [0]:
4. Monitoring and Troubleshooting
Create a notebook for monitoring and troubleshooting:

# Databricks notebook: Monitor ETL Pipeline

from pyspark.sql import SparkSession
from pyspark.sql.functions import col  # explicit import avoids shadowing Python built-ins

# Initialize Spark session
spark = SparkSession.builder.getOrCreate()

# Check data quality results
def check_data_quality_results():
    try:
        quality_results = spark.read.format("delta").load("s3://banking-data-lake/monitoring/data_quality_results/")
        
        # Show the latest results
        latest_results = quality_results.orderBy(col("execution_date").desc()).limit(10)
        display(latest_results)
        
        # Show failed checks
        failed_checks = quality_results.filter(~col("overall_passed"))
        print(f"Number of failed quality checks: {failed_checks.count()}")
        display(failed_checks)
    except Exception as e:
        print(f"Error checking data quality results: {str(e)}")

# Check pipeline execution logs
def check_pipeline_logs():
    try:
        logs = spark.read.text("s3://banking-data-lake/logs/")
        
        # Filter for errors
        error_logs = logs.filter(col("value").contains("ERROR"))
        print(f"Number of error logs: {error_logs.count()}")
        display(error_logs)
    except Exception as e:
        print(f"Error checking pipeline logs: {str(e)}")

# Check data counts
def read_redshift_count(table_name):
    """Return the row count of a Redshift table read over JDBC."""
    return spark.read.format("jdbc") \
        .option("url", "jdbc:redshift://banking-warehouse.xyz.us-east-1.redshift.amazonaws.com:5439/banking") \
        .option("dbtable", table_name) \
        .option("user", dbutils.secrets.get("redshift", "username")) \
        .option("password", dbutils.secrets.get("redshift", "password")) \
        .option("driver", "com.amazon.redshift.jdbc42.Driver") \
        .load() \
        .count()

def check_data_counts():
    try:
        # Check customer counts
        customer_count = read_redshift_count("dim_customer")
        print(f"Customer count in Redshift: {customer_count}")
        
        # Check account counts
        account_count = read_redshift_count("dim_account")
        print(f"Account count in Redshift: {account_count}")
        
        # Check transaction counts
        transaction_count = spark.read.format("delta") \
            .load("s3://banking-data-lake/processed/transactions/") \
            .count()
        
        print(f"Transaction count in Delta Lake: {transaction_count}")
    except Exception as e:
        print(f"Error checking data counts: {str(e)}")

# Run the monitoring functions
print("=== Data Quality Results ===")
check_data_quality_results()

print("\n=== Pipeline Logs ===")
check_pipeline_logs()

print("\n=== Data Counts ===")
check_data_counts()

In [0]:
Explanation: This monitoring notebook provides essential visibility into the ETL pipeline’s operation:

Data Quality Monitoring:
- Reads data quality check results from the monitoring location
- Shows the most recent results
- Highlights any failed quality checks
Log Analysis:
- Reads pipeline execution logs
- Filters for error messages
- Displays count and details of errors
Data Reconciliation:
- Counts records in target systems (Redshift and Delta Lake)
- Helps identify potential data loss or duplication when the counts are compared against source systems or previous runs
- Provides a quick sanity check on pipeline results
This monitoring approach is crucial for:

- Detecting issues early
- Ensuring data quality
- Validating successful data processing
- Troubleshooting pipeline failures
For a production environment, these checks could be automated and integrated with alerting systems.
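
As a minimal sketch of that automation: the monitoring functions above only print their findings and catch their own exceptions, so a scheduled run would always look successful. One option is to have them return their numbers and pass those to a small gate that raises when something is wrong, since an uncaught exception fails the notebook task and the job's failure notification can then do the alerting. The names `MonitoringSummary`, `find_issues`, and `assert_pipeline_healthy`, as well as the thresholds, are hypothetical and not part of the notebook above.

```python
from dataclasses import dataclass, field


@dataclass
class MonitoringSummary:
    """Numbers collected by the monitoring functions (hypothetical container)."""
    failed_quality_checks: int
    error_log_lines: int
    row_counts: dict = field(default_factory=dict)


def find_issues(summary, min_row_counts, max_error_log_lines=0):
    """Return a list of readable issues; an empty list means every check passed."""
    issues = []
    if summary.failed_quality_checks > 0:
        issues.append(f"{summary.failed_quality_checks} data quality check(s) failed")
    if summary.error_log_lines > max_error_log_lines:
        issues.append(f"{summary.error_log_lines} ERROR line(s) found in pipeline logs")
    for table, minimum in min_row_counts.items():
        actual = summary.row_counts.get(table)
        if actual is None:
            issues.append(f"no row count collected for {table}")
        elif actual < minimum:
            issues.append(f"{table} has {actual} rows, expected at least {minimum}")
    return issues


def assert_pipeline_healthy(summary, min_row_counts):
    """Raise so the notebook task fails and the job's failure alert fires."""
    issues = find_issues(summary, min_row_counts)
    if issues:
        raise RuntimeError("Monitoring found issues: " + "; ".join(issues))
```

In the notebook this would be called once at the end, for example `assert_pipeline_healthy(summary, {"dim_customer": 1, "dim_account": 1})`, with the minimum counts chosen to suit the actual data volumes.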

In [0]:
Scheduling the Pipeline
To schedule the ETL pipeline in Databricks:

Create a Databricks job:
- Go to the Databricks workspace and click on “Jobs”
- Click “Create Job”
- Add the “Run ETL Pipeline” notebook as a task
- Configure the cluster to use
- Set up a schedule (e.g., daily at 2 AM)
- Configure email notifications for success/failure
Set up monitoring alerts:
- Add the “Monitor ETL Pipeline” notebook as a separate task
- Configure it to run after the main pipeline task
- Set up email notifications for critical issues
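
The same setup can also be expressed as a job definition instead of UI clicks. The sketch below builds a settings dictionary in the shape used by the Databricks Jobs API 2.1 (`tasks`, `depends_on`, `schedule`, `email_notifications`). The job name, task keys, notebook paths, cluster ID, and the UTC time zone are all assumptions for illustration, and the field names should be checked against the Jobs API reference for your workspace before use.

```python
import json


def build_job_settings(pipeline_notebook, monitor_notebook, cluster_id, alert_emails):
    """Build a job settings dict with an ETL task followed by a monitoring task."""
    return {
        "name": "banking-etl-daily",  # hypothetical job name
        "tasks": [
            {
                "task_key": "run_etl_pipeline",
                "notebook_task": {"notebook_path": pipeline_notebook},
                "existing_cluster_id": cluster_id,
            },
            {
                # Runs only after the main pipeline task has finished successfully
                "task_key": "monitor_etl_pipeline",
                "depends_on": [{"task_key": "run_etl_pipeline"}],
                "notebook_task": {"notebook_path": monitor_notebook},
                "existing_cluster_id": cluster_id,
            },
        ],
        "schedule": {
            # Quartz cron: second minute hour day-of-month month day-of-week
            "quartz_cron_expression": "0 0 2 * * ?",  # daily at 2 AM
            "timezone_id": "UTC",  # assumption; use the bank's reporting time zone
            "pause_status": "UNPAUSED",
        },
        "email_notifications": {
            "on_success": alert_emails,
            "on_failure": alert_emails,
        },
    }


if __name__ == "__main__":
    settings = build_job_settings(
        "/Shared/etl/Run ETL Pipeline",      # hypothetical notebook path
        "/Shared/etl/Monitor ETL Pipeline",  # hypothetical notebook path
        "0000-000000-example",               # placeholder cluster ID
        ["data-team@example.com"],
    )
    print(json.dumps(settings, indent=2))
```

Keeping the definition in code makes the schedule and task dependencies reviewable and repeatable across workspaces. Submitting it (for example through the Databricks CLI or REST API) is left out here.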
Explanation: Scheduling ensures that our ETL pipeline runs automatically at the appropriate times:

Job Configuration:
- The main pipeline task processes the data
- The monitoring task validates results and checks for issues
- Email notifications alert the team to any problems
Scheduling Considerations:
- Schedule during off-peak hours to minimize impact on source systems
- Allow sufficient time for the pipeline to complete before business hours
- Consider dependencies on other data processes
Advanced Scheduling:
- For more complex workflows, consider using Databricks Workflows
- Implement retry logic for transient failures
- Set up SLA monitoring for timely completion
Proper scheduling is essential for maintaining up-to-date data in the banking environment where timely information is critical for decision-making and reporting.
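
The retry point above can be illustrated with a small wrapper that re-runs a step with exponential backoff. This is a sketch under stated assumptions, not code from the pipeline: `TransientError` is a hypothetical marker exception, `run_with_retries` is an illustrative helper, and the delays are arbitrary defaults. Databricks job tasks also expose a retry policy in their settings, which may be sufficient without any custom code.

```python
import time


class TransientError(Exception):
    """Hypothetical marker for failures worth retrying (e.g. a dropped connection)."""


def run_with_retries(step, max_attempts=3, base_delay_seconds=30.0, sleep=time.sleep):
    """Call step() until it succeeds, retrying only on TransientError.

    Waits base_delay_seconds after the first failure and doubles the wait
    after each further failure. Once max_attempts is exhausted, the last
    TransientError is re-raised. Any other exception propagates immediately.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return step()
        except TransientError:
            if attempt == max_attempts:
                raise
            sleep(base_delay_seconds * 2 ** (attempt - 1))
```

Retrying only a dedicated exception type keeps genuine data or logic errors from being masked by repeated attempts. The `sleep` parameter exists so the backoff can be replaced in tests.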

Conclusion
In this guide, we’ve built an end-to-end ETL pipeline for the banking domain using AWS, PySpark, and Databricks. The pipeline includes:

- Data Ingestion: Reading data from various sources like S3 and RDS
- Data Transformation: Cleaning, enriching, and validating data
- Data Quality: Implementing robust data quality checks
- Data Loading: Loading processed data to Redshift and S3
- Monitoring and Alerting: Setting up monitoring for the pipeline
This architecture provides a scalable, maintainable, and production-ready solution for processing banking data. The modular design allows for easy extension and modification as business requirements evolve.

Key benefits of this implementation:

- Scalability: Leverages Spark’s distributed computing capabilities to handle growing data volumes
- Reliability: Includes error handling, data quality checks, and monitoring to ensure dependable operation
- Maintainability: Modular design with clear separation of concerns makes the codebase easier to maintain
- Security: Implements secure data handling practices essential for sensitive banking information
- Flexibility: Supports multiple data sources and targets with a configuration-driven approach
The pipeline we’ve built addresses common challenges in banking data processing:

- Handling diverse data types and sources
- Ensuring data quality and consistency
- Detecting anomalous transactions
- Maintaining historical data with proper partitioning
- Providing a foundation for analytics and reporting
By following this guide, you can implement a robust ETL pipeline for your banking data processing needs, enabling the advanced analytics, regulatory reporting, and data-driven decision-making that are essential in today’s competitive banking industry.