In [1]:
import time

In [None]:
from pyspark.sql import SparkSession

# Stop existing Spark session if it exists
if SparkSession.getActiveSession():
    SparkSession.getActiveSession().stop()

spark = SparkSession.builder \
    .appName("Takumi ETL") \
    .master("local[2]") \
    .config("spark.driver.memory", "5g").config("spark.executor.memory", "3g").config("spark.sql.shuffle.partitions", "4") \
    .config("spark.hadoop.io.nativeio.disable", "true") \
    .getOrCreate()


In [3]:
date = "2025-03-25"
start_time = time.time()
file_path = f"input_data/input_data_2025-03-25/input_data_2025-03-25_batch_1.parquet"

input_data = spark.read.parquet(file_path)
symbol_data = spark.read.parquet(f"reference_market_data/ref_market_data_{date}.parquet")
currency_data = spark.read.parquet("reference_data/ref_currency_data.parquet")
exchange_data = spark.read.parquet("reference_data/ref_exchange_data.parquet")
order_types_data = spark.read.parquet("reference_data/ref_order_types_data.parquet")
sides_data = spark.read.parquet("reference_data/ref_sides_data.parquet")
transaction_types_data = spark.read.parquet("reference_data/ref_transaction_types_data.parquet")
order_statuses_data = spark.read.parquet("reference_data/ref_order_statuses_data.parquet")
mics_data = spark.read.parquet("reference_data/ref_mics_data.parquet")
timing_data = spark.read.parquet("reference_data/ref_market_timing_data.parquet")
end_time = time.time()

execution_time = end_time - start_time
print(f"Execution Time for Input_data Reading: {execution_time:.6f} seconds")


Execution Time for Input_data Reading: 8.347760 seconds


In [None]:
input_data.show()

In [4]:
import yaml
import time
import os
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, DoubleType

class Validator:
    def __init__(self, df, config_path):
        self.df = df.withColumn("validation_flag", F.lit(""))

        # Load config
        with open(config_path, "r") as file:
            self.config = yaml.safe_load(file)

        # Define expected data types
        dtype_map = {
            "string": StringType(),
            "double": DoubleType()
        }
        self.expected_dtypes = {
            col: dtype_map[self.config["validation"]["expected_dtypes"][col]]
            for col in self.config["validation"].get("expected_dtypes", {})
            if self.config["validation"]["expected_dtypes"][col] in dtype_map
        }

    def add_flag(self, condition, issue):
        """Appends an issue to the validation_flag column."""
        self.df = self.df.withColumn(
            "validation_flag",
            F.when(condition, F.concat_ws(", ", F.col("validation_flag"), F.lit(issue)))
            .otherwise(F.col("validation_flag"))
        )

    def check_missing_values(self):
        """Checks for missing values in required columns."""
        for col in self.config["validation"].get("required_columns", []):
            self.add_flag(F.col(col).isNull(), f"{col}_missing")

    def check_data_types(self):
        """Validates column data types."""
        for col, expected_type in self.expected_dtypes.items():
            if col in self.df.columns:
                self.add_flag(
                    F.col(col).cast(expected_type) != F.col(col), f"{col}_dtype_mismatch"
                )

    def fix_categorical(self):
        """Replaces invalid categorical values with a default value."""
        for col, settings in self.config["validation"].get("categorical", {}).items():
            valid_values = settings["valid_values"]
            default = settings["default"]
            valid_values_expr = F.when(F.col(col).isin(valid_values), F.col(col)).otherwise(default)
            self.df = self.df.withColumn(col, valid_values_expr)

    def fix_regex(self):
        """Validates a column against a regex pattern and replaces invalid values."""
        for col, settings in self.config["validation"].get("regex", {}).items():
            pattern = settings["pattern"]
            replacement = settings["replacement"]
            mask = F.col(col).rlike(pattern)
            self.df = self.df.withColumn(col, F.when(mask, F.col(col)).otherwise(replacement))

    def run_validations(self):
        """Executes all validation steps."""
        self.check_missing_values()
        self.check_data_types()
        self.fix_categorical()
        self.fix_regex()
        return self.df

# Usage Example:
start_time = time.time()
config_path = "configurations/validation_configurations.yaml"  # Path to the YAML file
validator = Validator(input_data, config_path)
validated_data = validator.run_validations()
validated_data.count()
print(validated_data.count())
end_time = time.time()

# Calculate execution time
execution_time = end_time - start_time
print(f"Execution Time for Validating data: {execution_time:.6f} seconds")


1000
Execution Time for Validating data: 3.932632 seconds


In [5]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import row_number, current_timestamp, col
from pyspark.sql.window import Window
import os 
import time

def generate_output_data(ref_data, validated_data, symbol_data, files_per_day, date):
    spark = validated_data.sparkSession  

    (
        currency_data, exchange_data, order_types_data, sides_data,
        transaction_types_data, order_statuses_data, mics_data
    ) = ref_data
    
    # Join validated_data with reference tables
    output_data = (
        validated_data
        .join(transaction_types_data.selectExpr("transaction_type", "transaction_type_id"), "transaction_type", "left")
        .join(mics_data.selectExpr("mic_code", "mic_id"), "mic_code", "left")
        .join(order_statuses_data.selectExpr("order_status", "order_status_id"), "order_status", "left")
        .join(sides_data.selectExpr("side", "side_id"), "side", "left")
        .join(order_types_data.selectExpr("order_type_name as order_type", "order_type_id"), "order_type", "left")
        .join(currency_data.selectExpr("currency_name", "currency_id"), "currency_name", "left")
        .join(exchange_data.selectExpr("exchange_code", "exchange_id"), "exchange_code", "left")
        .join(symbol_data.selectExpr("symbol", "listing_internal_id"), "symbol", "left")
    )

    # Define column order
    column_order = [
        "transaction_id", "transaction_parent_id",
        "transaction_timestamp", "transaction_type_id", "mic_id",
        "order_status_id", "side_id", "order_type_id", "symbol",
        "isin", "price", "quantity", "trader_id", "broker_id",
        "exchange_id", "currency_id", "listing_internal_id",
        "creation_time", "last_update_time", "validation_flag"
    ]

    # Add timestamps
    output_data = (
        output_data
        .withColumn("creation_time", current_timestamp())
        .withColumn("last_update_time", current_timestamp())
        .select(*column_order)
    )

    # Repartition data
    output_data = output_data.repartition(files_per_day)

    # Define output path
    output_path = f"output_data"
    
    try:
        # Write the output data to the specified path
        output_data.write \
            .mode("overwrite") \
            .option("compression", "snappy") \
            .parquet(output_path)
        

        print("Output data written successfully.")
    except Exception as e:
        print(f"Error writing output data: {e}")

    return output_data

# Start time
start_time = time.time()

# Reference data
ref_data = (
    currency_data,
    exchange_data,
    order_types_data,
    sides_data,
    transaction_types_data,
    order_statuses_data,
    mics_data,
)

# Create output directory if it doesn't exist
os.makedirs("output_data", exist_ok=True)

# Generate output data
output_data = generate_output_data(ref_data, validated_data, symbol_data, 20, date)

# End time
end_time = time.time()

# Calculate execution time
execution_time = end_time - start_time
print(f"Execution Time for output data generation: {execution_time:.2f} seconds")

Output data written successfully.
Execution Time for output data generation: 6.70 seconds


In [None]:
output_data.show()

In [6]:
spark

In [25]:
# import boto3
# import os
# import time

# start = time.time()

# s3 = boto3.client(
#     's3',
#     aws_access_key_id="",
#     aws_secret_access_key="",
#     region_name=""
# )

# local_folder = "output_data"
# s3_bucket = "output-data-dump-bucket"
# s3_prefix = "parquet/"  

# for file in os.listdir(local_folder):
#     if file.endswith(".snappy.parquet") and not file.startswith(("_", ".")):  
#         s3.upload_file(os.path.join(local_folder, file), s3_bucket, s3_prefix + file)
#         print(f"Uploaded {file} to s3://{s3_bucket}/{s3_prefix}")

# end = time.time()
# execution_time = end - start
# print(f"Execution time for dumping files in s3 bucket: {execution_time:.6f} seconds")


In [26]:
# print(f"Column Count: {len(output_data.columns)}")

# from pyspark.sql.functions import col,sum
# output_data.select([sum(col(c).isNull().cast("int")).alias(c) for c in output_data.columns]).show()

# output_data.dtypes

In [27]:
from pyspark.sql.functions import col

# Select a sample of rows from validated_data
sample_rows = validated_data.sample(withReplacement=False, fraction=0.0001).limit(5)
print(sample_rows.count())
# display(sample_rows.limit(10))
sample_rows.select("transaction_id","transaction_timestamp","transaction_type","validation_flag").orderBy("transaction_timestamp").show()

# Extract transaction IDs from the sampled rows for filtering output_data
sample_transaction_ids = [row["transaction_id"] for row in sample_rows.collect()]

# Filter output_data based on selected transaction IDs
filtered_output_data = output_data.filter(col("transaction_id").isin(sample_transaction_ids))
print(filtered_output_data.count())
# display(filtered_output_data.limit(10))
filtered_output_data.select("transaction_id","transaction_timestamp","transaction_type_id","validation_flag").orderBy("transaction_timestamp").show()




0
+--------------+---------------------+----------------+---------------+
|transaction_id|transaction_timestamp|transaction_type|validation_flag|
+--------------+---------------------+----------------+---------------+
+--------------+---------------------+----------------+---------------+

0
+--------------+---------------------+-------------------+---------------+
|transaction_id|transaction_timestamp|transaction_type_id|validation_flag|
+--------------+---------------------+-------------------+---------------+
+--------------+---------------------+-------------------+---------------+

