In [None]:
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql import types as T
from datetime import datetime
import json
import hashlib

class F1SilverLayer:
    """Incremental bronze-to-silver processor for F1 race result data.

    The pipeline reads race JSON from ``bronze_path``, keeps only records
    whose ``ingestion_timestamp`` is newer than the last checkpoint, emits one
    row per race result with flattened circuit/driver/constructor fields,
    adds validity and null flags plus batch metadata, and appends the result
    to Parquet under ``silver_path``. A metrics table and a checkpoint row
    are written after each successful batch.
    """

    # Lower bound returned when no checkpoint can be read yet, so that the
    # first run picks up every bronze record.
    _NO_CHECKPOINT_TIMESTAMP = "1900-01-01T00:00:00"

    # Explicit schema for the metrics table. Relying on inference fails when
    # ``metric_value`` mixes Python ints and floats across rows.
    _METRICS_SCHEMA = (
        "metric_name string, metric_value double, "
        "batch_id string, calculated_at string"
    )

    def __init__(self,
                 spark,
                 bronze_path="/content/sample_data/bronze/",
                 silver_path="/content/sample_data/silver/",
                 checkpoint_path="/content/sample_data/silver_checkpoint/"):
        """Store the session and paths, then apply session-level Parquet settings.

        Args:
            spark: Active SparkSession used for all reads and writes.
            bronze_path: Directory holding the ``season=*`` bronze JSON folders.
            silver_path: Output directory of the silver Parquet table.
            checkpoint_path: Directory of the Parquet checkpoint log.
        """
        self.spark = spark
        self.bronze_path = bronze_path
        self.silver_path = silver_path
        self.checkpoint_path = checkpoint_path

        # Session-level Parquet writer settings (these mutate the shared session).
        self.spark.conf.set("spark.sql.parquet.compression.codec", "zstd")
        self.spark.conf.set("spark.sql.parquet.enableDictionaryEncoding", "true")
        self.spark.conf.set("spark.sql.parquet.block.size", 256*1024*1024)

        # Build the reference schemas once per instance.
        self.define_schemas()

    def define_schemas(self):
        """Build ``self.circuit_schema`` and ``self.result_schema``.

        The schemas are only stored on the instance; no other method of this
        class applies them to a read or a cast.
        """
        self.circuit_schema = T.StructType([
            T.StructField("circuitId", T.StringType(), True),
            T.StructField("circuitName", T.StringType(), True),
            T.StructField("lat", T.DoubleType(), True),
            T.StructField("long", T.DoubleType(), True),
            T.StructField("locality", T.StringType(), True),
            T.StructField("country", T.StringType(), True)
        ])

        self.result_schema = T.StructType([
            T.StructField("constructorId", T.StringType(), True),
            T.StructField("constructorName", T.StringType(), True),
            T.StructField("driverId", T.StringType(), True),
            T.StructField("driverName", T.StringType(), True),
            T.StructField("position", T.IntegerType(), True),
            T.StructField("points", T.DoubleType(), True),
            T.StructField("grid", T.IntegerType(), True),
            T.StructField("laps", T.IntegerType(), True),
            T.StructField("status", T.StringType(), True),
            T.StructField("time", T.StringType(), True)
        ])

    def get_last_processed_timestamp(self):
        """Return the newest ``processed_timestamp`` recorded in the checkpoint.

        Returns:
            The maximum ``processed_timestamp`` stored under
            ``checkpoint_path``, or ``"1900-01-01T00:00:00"`` when the
            checkpoint cannot be resolved (for example on the first run) or
            holds no value.

        Raises:
            Any non-analysis error from reading the checkpoint. These are
            deliberately not swallowed: falling back to the sentinel would
            reprocess all bronze data and append duplicates to silver.
        """
        # Local import keeps this class self-contained; pyspark is already a
        # dependency of the file.
        from pyspark.sql.utils import AnalysisException

        try:
            checkpoint_df = self.spark.read.parquet(self.checkpoint_path)
            latest = checkpoint_df.agg(F.max("processed_timestamp")).first()[0]
        except AnalysisException:
            # Missing path or no readable Parquet files: nothing processed yet.
            return self._NO_CHECKPOINT_TIMESTAMP

        if latest is None:
            return self._NO_CHECKPOINT_TIMESTAMP
        return latest

    def process_bronze_data(self):
        """Read bronze data and transform the records newer than the checkpoint.

        Returns:
            The transformed silver DataFrame, or ``None`` when no bronze
            record has an ``ingestion_timestamp`` after the last checkpoint.
        """
        # Strip trailing slashes so the glob never contains "//".
        bronze_glob = f"{self.bronze_path.rstrip('/')}/season=*"
        bronze_df = self.spark.read.json(bronze_glob)

        last_processed = self.get_last_processed_timestamp()

        incremental_df = bronze_df.filter(
            F.col("ingestion_timestamp") > last_processed
        )

        # Fetching a single row is enough to know whether there is new data;
        # a full count() would scan everything.
        if not incremental_df.head(1):
            print("No new data to process")
            return None

        return self.transform_bronze_to_silver(incremental_df)

    def transform_bronze_to_silver(self, df):
        """Flatten bronze race records into one silver row per race result.

        Args:
            df: Bronze DataFrame with the race-level fields selected below, a
                ``Circuit`` struct and a ``Results`` array.

        Returns:
            DataFrame with flattened columns, ``race_timestamp``,
            ``driver_full_name``, ``row_hash``, the data-quality flags,
            ``processed_timestamp`` and ``silver_batch_id``.
        """
        # Local import: uuid is stdlib and only needed here.
        import uuid

        # 1. One row per entry of the Results array.
        df = df.select(
            F.col("season"),
            F.col("round"),
            F.col("raceName"),
            F.col("date"),
            F.col("time"),
            F.to_timestamp(F.col("ingestion_timestamp")).alias("ingestion_timestamp"),
            F.col("url"),
            F.col("Circuit").alias("circuit"),
            F.explode("Results").alias("result")
        )

        # 2. Pull nested struct fields up to top-level columns.
        df = df.select(
            "*",
            F.col("circuit.circuitId").alias("circuit_id"),
            F.col("circuit.circuitName").alias("circuit_name"),
            F.col("circuit.Location.lat").alias("circuit_lat"),
            F.col("circuit.Location.long").alias("circuit_long"),
            F.col("circuit.Location.locality").alias("circuit_locality"),
            F.col("circuit.Location.country").alias("circuit_country"),
            F.col("result.Constructor.constructorId").alias("constructor_id"),
            F.col("result.Constructor.name").alias("constructor_name"),
            F.col("result.Driver.driverId").alias("driver_id"),
            F.col("result.Driver.givenName").alias("driver_given_name"),
            F.col("result.Driver.familyName").alias("driver_family_name"),
            F.col("result.position").alias("position"),
            F.col("result.points").alias("points"),
            F.col("result.grid").alias("grid"),
            F.col("result.laps").alias("laps"),
            F.col("result.status").alias("status"),
            F.col("result.Time.time").alias("finish_time")
        ).drop("circuit", "result")

        # 3. Combine date and time into a single timestamp. concat() yields
        #    NULL when either part is NULL, which later fails is_valid_date.
        df = df.withColumn("race_timestamp",
                          F.to_timestamp(
                              F.concat(F.col("date"), F.lit(" "), F.col("time")),
                              "yyyy-MM-dd HH:mm:ssX"
                          ))

        # 4. Computed display column.
        df = df.withColumn("driver_full_name",
                          F.concat(F.col("driver_given_name"), F.lit(" "), F.col("driver_family_name")))

        # 5. Hash of the identifying fields, used for change detection.
        columns_for_hash = ["season", "round", "driver_id", "constructor_id", "position"]
        df = df.withColumn("row_hash",
                          F.sha2(F.concat_ws("|", *[F.col(c) for c in columns_for_hash]), 256))

        # 6. Data quality flags.
        df = self.add_data_quality_checks(df)

        # 7. Metadata. The batch id is generated once on the driver so every
        #    row of the batch carries the same value and it stays stable when
        #    the DataFrame is re-evaluated; a per-row uuid() expression would
        #    give each row (and each action) a different id.
        batch_id = str(uuid.uuid4())
        df = df.withColumn("processed_timestamp", F.current_timestamp())
        df = df.withColumn("silver_batch_id", F.lit(batch_id))

        return df

    def add_data_quality_checks(self, df):
        """Add validity flags and one ``is_null_<column>`` flag per column.

        Args:
            df: DataFrame containing ``position``, ``points``, ``grid`` and
                ``race_timestamp``.

        Returns:
            ``df`` plus the boolean columns ``is_valid_position``,
            ``is_valid_points``, ``is_valid_grid``, ``is_valid_date`` and
            ``is_valid_record``, followed by an integer 0/1 ``is_null_<name>``
            column for every column present at that point (the validity
            flags included).
        """
        df = df.withColumn("is_valid_position",
                          (F.col("position").isNotNull() & (F.col("position") >= 1)))

        df = df.withColumn("is_valid_points",
                          (F.col("points").isNotNull() & (F.col("points") >= 0)))

        df = df.withColumn("is_valid_grid",
                          (F.col("grid").isNotNull() & (F.col("grid") >= 0)))

        df = df.withColumn("is_valid_date",
                          F.col("race_timestamp").isNotNull())

        # A record is valid only when every individual check passes.
        df = df.withColumn("is_valid_record",
                          F.col("is_valid_position") &
                          F.col("is_valid_points") &
                          F.col("is_valid_grid") &
                          F.col("is_valid_date"))

        # Add all null flags in one projection; chaining withColumn() in a
        # loop grows the query plan by one projection per column.
        null_flags = [
            F.when(F.col(name).isNull(), 1).otherwise(0).alias(f"is_null_{name}")
            for name in df.columns
        ]
        return df.select("*", *null_flags)

    def write_to_silver(self, df):
        """Append a batch to the silver table, then write metrics and checkpoint.

        Args:
            df: Silver DataFrame from ``transform_bronze_to_silver``; ``None``
                or an empty DataFrame is a no-op.
        """
        if df is None:
            return

        # Several actions follow (write, metrics, checkpoint). Caching avoids
        # re-running the whole bronze-to-silver pipeline for each of them and
        # keeps the written data and its metrics consistent.
        df = df.cache()
        try:
            if not df.head(1):
                return

            # Main silver table, partitioned by season.
            (df.write
             .mode("append")
             .partitionBy("season")
             .format("parquet")
             .option("compression", "zstd")
             .save(self.silver_path))

            # NOTE: with a silver_path ending in "/" this resolves to
            # "<silver_path>/_metrics", i.e. a folder inside the silver table
            # directory. Kept as-is so existing readers keep working.
            quality_metrics = self.calculate_quality_metrics(df)
            (quality_metrics.write
             .mode("append")
             .format("parquet")
             .save(f"{self.silver_path}_metrics"))

            # The checkpoint is written last so a failed batch is retried.
            self.update_checkpoint(df)
        finally:
            df.unpersist()

    def calculate_quality_metrics(self, df):
        """Compute per-column null percentages and batch-level counts.

        All counts come from a single aggregation over ``df``.

        Args:
            df: Silver DataFrame containing ``is_valid_record`` and
                ``silver_batch_id``.

        Returns:
            DataFrame with columns ``metric_name`` (string), ``metric_value``
            (double), ``batch_id`` (string) and ``calculated_at`` (ISO string).
            It holds one ``null_percentage_<column>`` row per column plus
            ``total_records`` and ``invalid_records_percentage``. Percentages
            are ``0.0`` for an empty input.
        """
        column_names = df.columns

        aggregates = [
            F.count(F.lit(1)).alias("total_count"),
            # when() without otherwise() yields NULL for non-matching rows,
            # and count() ignores NULLs, so this counts matching rows only.
            F.count(F.when(~F.col("is_valid_record"), 1)).alias("invalid_count"),
            F.first("silver_batch_id").alias("batch_id"),
        ]
        aggregates.extend(
            F.count(F.when(F.col(name).isNull(), 1)).alias(f"null_count_{position}")
            for position, name in enumerate(column_names)
        )

        summary = df.agg(*aggregates).first()
        total_count, invalid_count, batch_id = summary[0], summary[1], summary[2]
        null_counts = summary[3:]
        calculated_at = datetime.now().isoformat()

        def as_percentage(part):
            # Guard against an empty batch instead of dividing by zero.
            return (part / total_count) * 100 if total_count > 0 else 0.0

        # metric_value is always a float so the column has a single type.
        metrics = [
            (f"null_percentage_{name}", float(as_percentage(null_count)), batch_id, calculated_at)
            for name, null_count in zip(column_names, null_counts)
        ]
        metrics.append(("total_records", float(total_count), batch_id, calculated_at))
        metrics.append(
            ("invalid_records_percentage", float(as_percentage(invalid_count)), batch_id, calculated_at)
        )

        return self.spark.createDataFrame(metrics, schema=self._METRICS_SCHEMA)

    def update_checkpoint(self, df):
        """Append a checkpoint row describing the batch that was just written.

        Args:
            df: The silver DataFrame of the batch. Its maximum
                ``ingestion_timestamp`` becomes the checkpoint's
                ``processed_timestamp``, the lower bound for the next run.
        """
        # One aggregation instead of three separate Spark jobs.
        summary = df.agg(
            F.max("ingestion_timestamp"),
            F.first("silver_batch_id"),
            F.count(F.lit(1)),
        ).first()

        checkpoint_data = [{
            "processed_timestamp": summary[0],
            "silver_batch_id": summary[1],
            "record_count": summary[2],
            "checkpoint_timestamp": datetime.now().isoformat()
        }]

        (self.spark.createDataFrame(checkpoint_data)
         .write
         .mode("append")
         .parquet(self.checkpoint_path))

    def process(self):
        """Run one incremental bronze-to-silver batch.

        Raises:
            Exception: Any error from reading, transforming or writing is
                printed and re-raised unchanged.
        """
        try:
            print("Starting Silver Layer Processing...")

            silver_df = self.process_bronze_data()

            # process_bronze_data() returns None when nothing is new.
            if silver_df is not None:
                print(f"Writing {silver_df.count()} records to silver layer...")
                self.write_to_silver(silver_df)
                print("Silver Layer Processing Complete!")
            else:
                print("No new data to process")

        except Exception as e:
            print(f"Error in Silver Layer Processing: {str(e)}")
            raise

# Usage: build a Delta-enabled session and run a single silver-layer batch.
if __name__ == "__main__":
    spark = (
        SparkSession.builder
        .appName("F1SilverLayer")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
        .getOrCreate()
    )

    # Default bronze/silver/checkpoint paths are used.
    silver_layer = F1SilverLayer(spark)
    silver_layer.process()

In [None]:
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql import types as T
from datetime import datetime
import json
import hashlib

class F1SilverLayer:
    """Incremental bronze-to-silver processor for F1 race result data.

    The pipeline reads race JSON from ``bronze_path``, keeps only records
    whose ``ingestion_timestamp`` is newer than the last checkpoint, emits one
    row per race result with flattened circuit/driver/constructor fields,
    adds validity flags plus batch metadata, and appends the result to
    Parquet under ``silver_path``. A metrics table and a checkpoint row are
    written after each successful batch.
    """

    # Lower bound returned when no checkpoint can be read yet, so that the
    # first run picks up every bronze record.
    _NO_CHECKPOINT_TIMESTAMP = "1900-01-01T00:00:00"

    # Explicit schema for the metrics table. Relying on inference fails when
    # ``metric_value`` mixes Python ints and floats across rows.
    _METRICS_SCHEMA = (
        "metric_name string, metric_value double, "
        "batch_id string, calculated_at string"
    )

    def __init__(self,
                 spark,
                 bronze_path="/content/sample_data/bronze/",
                 silver_path="/content/sample_data/silver/",
                 checkpoint_path="/content/sample_data/silver_checkpoint/"):
        """Store the session and paths, then apply session-level Parquet settings.

        Args:
            spark: Active SparkSession used for all reads and writes.
            bronze_path: Directory holding the ``season=*`` bronze JSON folders.
            silver_path: Output directory of the silver Parquet table.
            checkpoint_path: Directory of the Parquet checkpoint log.
        """
        self.spark = spark
        self.bronze_path = bronze_path
        self.silver_path = silver_path
        self.checkpoint_path = checkpoint_path

        # Session-level Parquet writer settings (these mutate the shared session).
        self.spark.conf.set("spark.sql.parquet.compression.codec", "snappy")
        self.spark.conf.set("spark.sql.parquet.enableDictionaryEncoding", "true")
        self.spark.conf.set("spark.sql.parquet.block.size", 256*1024*1024)

    def get_last_processed_timestamp(self):
        """Return the newest ``processed_timestamp`` recorded in the checkpoint.

        Returns:
            The maximum ``processed_timestamp`` stored under
            ``checkpoint_path``, or ``"1900-01-01T00:00:00"`` when the
            checkpoint cannot be resolved (for example on the first run) or
            holds no value.

        Raises:
            Any non-analysis error from reading the checkpoint. These are
            deliberately not swallowed: falling back to the sentinel would
            reprocess all bronze data and append duplicates to silver.
        """
        # Local import keeps this class self-contained; pyspark is already a
        # dependency of the file.
        from pyspark.sql.utils import AnalysisException

        try:
            checkpoint_df = self.spark.read.parquet(self.checkpoint_path)
            latest = checkpoint_df.agg(F.max("processed_timestamp")).first()[0]
        except AnalysisException:
            # Missing path or no readable Parquet files: nothing processed yet.
            return self._NO_CHECKPOINT_TIMESTAMP

        if latest is None:
            return self._NO_CHECKPOINT_TIMESTAMP
        return latest

    def process_bronze_data(self):
        """Read bronze data and transform the records newer than the checkpoint.

        Returns:
            The transformed silver DataFrame, or ``None`` when no bronze
            record has an ``ingestion_timestamp`` after the last checkpoint.
        """
        # Use the session and path given to the constructor rather than a
        # module-level ``spark`` and a hard-coded location. Trailing slashes
        # are stripped so the glob never contains "//".
        bronze_glob = f"{self.bronze_path.rstrip('/')}/season=*"
        bronze_df = self.spark.read.json(bronze_glob)

        last_processed = self.get_last_processed_timestamp()

        incremental_df = bronze_df.filter(
            F.col("ingestion_timestamp") > last_processed
        )

        # Fetching a single row is enough to know whether there is new data;
        # a full count() would scan everything.
        if not incremental_df.head(1):
            print("No new data to process")
            return None

        return self.transform_bronze_to_silver(incremental_df)

    def transform_bronze_to_silver(self, df):
        """Flatten bronze race records into one silver row per race result.

        Args:
            df: Bronze DataFrame with the race-level fields selected below, a
                ``Circuit`` struct and a ``Results`` array.

        Returns:
            DataFrame with flattened columns, ``race_timestamp``,
            ``driver_full_name``, ``row_hash``, the data-quality flags,
            ``processed_timestamp`` and ``silver_batch_id``.
        """
        # Local import: uuid is stdlib and only needed here.
        import uuid

        # 1. One row per entry of the Results array.
        df = df.select(
            F.col("season"),
            F.col("round"),
            F.col("raceName"),
            F.col("date"),
            F.col("time"),
            F.to_timestamp(F.col("ingestion_timestamp")).alias("ingestion_timestamp"),
            F.col("url"),
            F.col("Circuit").alias("circuit"),
            F.explode("Results").alias("result")
        )

        # 2. Pull nested struct fields up to top-level columns.
        df = df.select(
            "*",
            F.col("circuit.circuitId").alias("circuit_id"),
            F.col("circuit.circuitName").alias("circuit_name"),
            F.col("circuit.Location.lat").alias("circuit_lat"),
            F.col("circuit.Location.long").alias("circuit_long"),
            F.col("circuit.Location.locality").alias("circuit_locality"),
            F.col("circuit.Location.country").alias("circuit_country"),
            F.col("result.Constructor.constructorId").alias("constructor_id"),
            F.col("result.Constructor.name").alias("constructor_name"),
            F.col("result.Driver.driverId").alias("driver_id"),
            F.col("result.Driver.givenName").alias("driver_given_name"),
            F.col("result.Driver.familyName").alias("driver_family_name"),
            F.col("result.position").alias("position"),
            F.col("result.points").alias("points"),
            F.col("result.grid").alias("grid"),
            F.col("result.laps").alias("laps"),
            F.col("result.status").alias("status"),
            F.col("result.Time.time").alias("finish_time")
        ).drop("circuit", "result")

        # 3. Combine date and time into a single timestamp. concat() yields
        #    NULL when either part is NULL, which later fails is_valid_date.
        df = df.withColumn("race_timestamp",
                          F.to_timestamp(
                              F.concat(F.col("date"), F.lit(" "), F.col("time")),
                              "yyyy-MM-dd HH:mm:ssX"
                          ))

        # 4. Computed display column.
        df = df.withColumn("driver_full_name",
                          F.concat(F.col("driver_given_name"), F.lit(" "), F.col("driver_family_name")))

        # 5. Hash of the identifying fields, used for change detection.
        columns_for_hash = ["season", "round", "driver_id", "constructor_id", "position"]
        df = df.withColumn("row_hash",
                          F.sha2(F.concat_ws("|", *[F.col(c) for c in columns_for_hash]), 256))

        # 6. Data quality flags.
        df = self.add_data_quality_checks(df)

        # 7. Metadata. The batch id is generated once on the driver so every
        #    row of the batch carries the same value and it stays stable when
        #    the DataFrame is re-evaluated; a per-row uuid() expression would
        #    give each row (and each action) a different id.
        batch_id = str(uuid.uuid4())
        df = df.withColumn("processed_timestamp", F.current_timestamp())
        df = df.withColumn("silver_batch_id", F.lit(batch_id))

        return df

    def add_data_quality_checks(self, df):
        """Add boolean validity flags for position, points, grid and date.

        Args:
            df: DataFrame containing ``position``, ``points``, ``grid`` and
                ``race_timestamp``.

        Returns:
            ``df`` plus ``is_valid_position``, ``is_valid_points``,
            ``is_valid_grid``, ``is_valid_date`` and the combined
            ``is_valid_record``.
        """
        df = df.withColumn("is_valid_position",
                          (F.col("position").isNotNull() & (F.col("position") >= 1)))

        df = df.withColumn("is_valid_points",
                          (F.col("points").isNotNull() & (F.col("points") >= 0)))

        df = df.withColumn("is_valid_grid",
                          (F.col("grid").isNotNull() & (F.col("grid") >= 0)))

        df = df.withColumn("is_valid_date",
                          F.col("race_timestamp").isNotNull())

        # A record is valid only when every individual check passes.
        df = df.withColumn("is_valid_record",
                          F.col("is_valid_position") &
                          F.col("is_valid_points") &
                          F.col("is_valid_grid") &
                          F.col("is_valid_date"))

        return df

    def write_to_silver(self, df):
        """Append a batch to the silver table, then write metrics and checkpoint.

        Args:
            df: Silver DataFrame from ``transform_bronze_to_silver``; ``None``
                or an empty DataFrame is a no-op.
        """
        if df is None:
            return

        # Several actions follow (write, metrics, checkpoint). Caching avoids
        # re-running the whole bronze-to-silver pipeline for each of them and
        # keeps the written data and its metrics consistent.
        df = df.cache()
        try:
            if not df.head(1):
                return

            # Main silver table, partitioned by season.
            (df.write
             .mode("append")
             .partitionBy("season")
             .format("parquet")
             .save(self.silver_path))

            # NOTE: with a silver_path ending in "/" this resolves to
            # "<silver_path>/_metrics", i.e. a folder inside the silver table
            # directory. Kept as-is so existing readers keep working.
            quality_metrics = self.calculate_quality_metrics(df)
            (quality_metrics.write
             .mode("append")
             .format("parquet")
             .save(f"{self.silver_path}_metrics"))

            # The checkpoint is written last so a failed batch is retried.
            self.update_checkpoint(df)
        finally:
            df.unpersist()

    def calculate_quality_metrics(self, df):
        """Compute per-column null percentages and batch-level counts.

        All counts come from a single aggregation over ``df``.

        Args:
            df: Silver DataFrame containing ``is_valid_record`` and
                ``silver_batch_id``.

        Returns:
            DataFrame with columns ``metric_name`` (string), ``metric_value``
            (double), ``batch_id`` (string) and ``calculated_at`` (ISO string).
            It holds one ``null_percentage_<column>`` row per column plus
            ``total_records`` and ``invalid_records_percentage``. Percentages
            are ``0.0`` for an empty input.
        """
        column_names = df.columns

        aggregates = [
            F.count(F.lit(1)).alias("total_count"),
            # when() without otherwise() yields NULL for non-matching rows,
            # and count() ignores NULLs, so this counts matching rows only.
            F.count(F.when(~F.col("is_valid_record"), 1)).alias("invalid_count"),
            F.first("silver_batch_id").alias("batch_id"),
        ]
        aggregates.extend(
            F.count(F.when(F.col(name).isNull(), 1)).alias(f"null_count_{position}")
            for position, name in enumerate(column_names)
        )

        summary = df.agg(*aggregates).first()
        total_count, invalid_count, batch_id = summary[0], summary[1], summary[2]
        null_counts = summary[3:]
        calculated_at = datetime.now().isoformat()

        def as_percentage(part):
            # Guard against an empty batch instead of dividing by zero.
            return (part / total_count) * 100 if total_count > 0 else 0.0

        # metric_value is always a float so the column has a single type.
        metrics = [
            (f"null_percentage_{name}", float(as_percentage(null_count)), batch_id, calculated_at)
            for name, null_count in zip(column_names, null_counts)
        ]
        metrics.append(("total_records", float(total_count), batch_id, calculated_at))
        metrics.append(
            ("invalid_records_percentage", float(as_percentage(invalid_count)), batch_id, calculated_at)
        )

        return self.spark.createDataFrame(metrics, schema=self._METRICS_SCHEMA)

    def update_checkpoint(self, df):
        """Append a checkpoint row describing the batch that was just written.

        Args:
            df: The silver DataFrame of the batch. Its maximum
                ``ingestion_timestamp`` becomes the checkpoint's
                ``processed_timestamp``, the lower bound for the next run.
        """
        # One aggregation instead of three separate Spark jobs.
        summary = df.agg(
            F.max("ingestion_timestamp"),
            F.first("silver_batch_id"),
            F.count(F.lit(1)),
        ).first()

        checkpoint_data = [{
            "processed_timestamp": summary[0],
            "silver_batch_id": summary[1],
            "record_count": summary[2],
            "checkpoint_timestamp": datetime.now().isoformat()
        }]

        (self.spark.createDataFrame(checkpoint_data)
         .write
         .mode("append")
         .parquet(self.checkpoint_path))

    def process(self):
        """Run one incremental bronze-to-silver batch.

        Raises:
            Exception: Any error from reading, transforming or writing is
                printed and re-raised unchanged.
        """
        try:
            print("Starting Silver Layer Processing...")

            silver_df = self.process_bronze_data()

            # process_bronze_data() returns None when nothing is new.
            if silver_df is not None:
                print(f"Writing {silver_df.count()} records to silver layer...")
                self.write_to_silver(silver_df)
                print("Silver Layer Processing Complete!")
            else:
                print("No new data to process")

        except Exception as e:
            print(f"Error in Silver Layer Processing: {str(e)}")
            raise

# Plain Spark session (no extra extensions), as used in the Colab notebook.
spark = (
    SparkSession.builder
    .appName("F1SilverLayer")
    .getOrCreate()
)

# Build the processor with explicit input, output and checkpoint locations.
silver_layer = F1SilverLayer(
    spark,
    bronze_path="/content/sample_data/bronze/",
    silver_path="/content/sample_data/silver/",
    checkpoint_path="/content/sample_data/silver_checkpoint/",
)

# Run one incremental batch.
silver_layer.process()